1use core::ops::{Index, IndexMut};
4
5use crate::prelude_dev::*;
6
7pub fn into_slice_f<S, D, I>(tensor: TensorBase<S, D>, index: I) -> Result<TensorBase<S, IxD>>
10where
11 D: DimAPI,
12 I: TryInto<AxesIndex<Indexer>, Error = Error>,
13{
14 let (data, layout) = tensor.into_raw_parts();
15 let index = index.try_into()?;
16 let layout = layout.dim_slice(index.as_ref())?;
17 return unsafe { Ok(TensorBase::new_unchecked(data, layout)) };
18}
19
20pub fn into_slice<S, D, I>(tensor: TensorBase<S, D>, index: I) -> TensorBase<S, IxD>
21where
22 D: DimAPI,
23 I: TryInto<AxesIndex<Indexer>, Error = Error>,
24{
25 into_slice_f(tensor, index).unwrap()
26}
27
28pub fn slice_f<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, index: I) -> Result<TensorView<'_, T, B, IxD>>
29where
30 D: DimAPI,
31 I: TryInto<AxesIndex<Indexer>, Error = Error>,
32 R: DataAPI<Data = B::Raw>,
33 B: DeviceAPI<T>,
34{
35 into_slice_f(tensor.view(), index)
36}
37
38pub fn slice<R, T, B, D, I>(tensor: &TensorAny<R, T, B, D>, index: I) -> TensorView<'_, T, B, IxD>
39where
40 D: DimAPI,
41 I: TryInto<AxesIndex<Indexer>, Error = Error>,
42 R: DataAPI<Data = B::Raw>,
43 B: DeviceAPI<T>,
44{
45 slice_f(tensor, index).unwrap()
46}
47
48impl<R, T, B, D> TensorAny<R, T, B, D>
49where
50 R: DataAPI<Data = B::Raw>,
51 B: DeviceAPI<T>,
52 D: DimAPI,
53{
54 pub fn into_slice_f<I>(self, index: I) -> Result<TensorAny<R, T, B, IxD>>
55 where
56 I: TryInto<AxesIndex<Indexer>, Error = Error>,
57 {
58 into_slice_f(self, index)
59 }
60
61 pub fn into_slice<I>(self, index: I) -> TensorAny<R, T, B, IxD>
62 where
63 I: TryInto<AxesIndex<Indexer>, Error = Error>,
64 {
65 into_slice(self, index)
66 }
67
68 pub fn slice_f<I>(&self, index: I) -> Result<TensorView<'_, T, B, IxD>>
69 where
70 I: TryInto<AxesIndex<Indexer>, Error = Error>,
71 {
72 slice_f(self, index)
73 }
74
75 pub fn slice<I>(&self, index: I) -> TensorView<'_, T, B, IxD>
76 where
77 I: TryInto<AxesIndex<Indexer>, Error = Error>,
78 {
79 slice(self, index)
80 }
81
82 pub fn i_f<I>(&self, index: I) -> Result<TensorView<'_, T, B, IxD>>
83 where
84 I: TryInto<AxesIndex<Indexer>, Error = Error>,
85 {
86 slice_f(self, index)
87 }
88
89 pub fn i<I>(&self, index: I) -> TensorView<'_, T, B, IxD>
90 where
91 I: TryInto<AxesIndex<Indexer>, Error = Error>,
92 {
93 slice(self, index)
94 }
95}
96
97pub fn slice_mut_f<R, T, B, D, I>(tensor: &mut TensorAny<R, T, B, D>, index: I) -> Result<TensorMut<'_, T, B, IxD>>
102where
103 D: DimAPI,
104 I: TryInto<AxesIndex<Indexer>, Error = Error>,
105 R: DataMutAPI<Data = B::Raw>,
106 B: DeviceAPI<T>,
107{
108 into_slice_f(tensor.view_mut(), index)
109}
110
111pub fn slice_mut<R, T, B, D, I>(tensor: &mut TensorAny<R, T, B, D>, index: I) -> TensorMut<'_, T, B, IxD>
112where
113 D: DimAPI,
114 I: TryInto<AxesIndex<Indexer>, Error = Error>,
115 R: DataMutAPI<Data = B::Raw>,
116 B: DeviceAPI<T>,
117{
118 slice_mut_f(tensor, index).unwrap()
119}
120
121impl<R, T, B, D> TensorAny<R, T, B, D>
122where
123 R: DataMutAPI<Data = B::Raw>,
124 B: DeviceAPI<T>,
125 D: DimAPI,
126{
127 pub fn slice_mut_f<I>(&mut self, index: I) -> Result<TensorMut<'_, T, B, IxD>>
128 where
129 I: TryInto<AxesIndex<Indexer>, Error = Error>,
130 {
131 slice_mut_f(self, index)
132 }
133
134 pub fn slice_mut<I>(&mut self, index: I) -> TensorMut<'_, T, B, IxD>
135 where
136 I: TryInto<AxesIndex<Indexer>, Error = Error>,
137 {
138 slice_mut(self, index)
139 }
140
141 pub fn i_mut_f<I>(&mut self, index: I) -> Result<TensorMut<'_, T, B, IxD>>
142 where
143 I: TryInto<AxesIndex<Indexer>, Error = Error>,
144 {
145 slice_mut_f(self, index)
146 }
147
148 pub fn i_mut<I>(&mut self, index: I) -> TensorMut<'_, T, B, IxD>
149 where
150 I: TryInto<AxesIndex<Indexer>, Error = Error>,
151 {
152 slice_mut(self, index)
153 }
154}
155
/// Arguments for [`into_diagonal_f`] and the other diagonal functions in this module.
///
/// Every field is optional. The diagonal functions destructure this struct and forward each
/// field unchanged to `Layout::diagonal`, which decides what a `None` means.
///
/// Usually built through its `From` impls: an integer offset, an `(offset, axis1, axis2)`
/// tuple, `()`, or an `Option<isize>` offset.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DiagonalArgs {
    /// Diagonal offset; `None` leaves the choice to `Layout::diagonal`.
    pub offset: Option<isize>,
    /// First of the two axes the diagonal is taken over; `None` leaves the choice to
    /// `Layout::diagonal`.
    pub axis1: Option<isize>,
    /// Second of the two axes the diagonal is taken over; `None` leaves the choice to
    /// `Layout::diagonal`.
    pub axis2: Option<isize>,
}
165
166#[duplicate_item(
167 S0 S1 S2;
168 [isize] [isize] [isize];
169 [usize] [isize] [isize];
170 [usize] [usize] [usize];
171 [i32 ] [i32 ] [i32 ];
172 [i64 ] [i64 ] [i64 ];
173)]
174#[allow(clippy::unnecessary_cast)]
175impl From<(S0, S1, S2)> for DiagonalArgs {
176 fn from(args: (S0, S1, S2)) -> Self {
177 let (offset, axis1, axis2) = args;
178 Self { offset: Some(offset as isize), axis1: Some(axis1 as isize), axis2: Some(axis2 as isize) }
179 }
180}
181
182#[duplicate_item(S; [isize]; [usize]; [i32]; [i64];)]
183#[allow(clippy::unnecessary_cast)]
184impl From<S> for DiagonalArgs {
185 fn from(offset: S) -> Self {
186 Self { offset: Some(offset as isize), axis1: None, axis2: None }
187 }
188}
189
190impl From<()> for DiagonalArgs {
191 fn from(_: ()) -> Self {
192 Self { offset: None, axis1: None, axis2: None }
193 }
194}
195
196impl From<Option<isize>> for DiagonalArgs {
197 fn from(offset: Option<isize>) -> Self {
198 Self { offset, axis1: None, axis2: None }
199 }
200}
201
202pub fn into_diagonal_f<S, D>(
203 tensor: TensorBase<S, D>,
204 diagonal_args: impl Into<DiagonalArgs>,
205) -> Result<TensorBase<S, D::SmallerOne>>
206where
207 D: DimAPI + DimSmallerOneAPI,
208 D::SmallerOne: DimAPI,
209{
210 let (data, layout) = tensor.into_raw_parts();
211 let DiagonalArgs { offset, axis1, axis2 } = diagonal_args.into();
212 let layout = layout.diagonal(offset, axis1, axis2)?;
213 return unsafe { Ok(TensorBase::new_unchecked(data, layout)) };
214}
215
216pub fn into_diagonal<S, D>(
217 tensor: TensorBase<S, D>,
218 diagonal_args: impl Into<DiagonalArgs>,
219) -> TensorBase<S, D::SmallerOne>
220where
221 D: DimAPI + DimSmallerOneAPI,
222 D::SmallerOne: DimAPI,
223{
224 into_diagonal_f(tensor, diagonal_args).unwrap()
225}
226
227pub fn diagonal_f<R, T, B, D>(
228 tensor: &TensorAny<R, T, B, D>,
229 diagonal_args: impl Into<DiagonalArgs>,
230) -> Result<TensorView<'_, T, B, D::SmallerOne>>
231where
232 D: DimAPI + DimSmallerOneAPI,
233 D::SmallerOne: DimAPI,
234 R: DataAPI<Data = B::Raw>,
235 B: DeviceAPI<T>,
236{
237 into_diagonal_f(tensor.view(), diagonal_args)
238}
239
240pub fn diagonal<R, T, B, D>(
241 tensor: &TensorAny<R, T, B, D>,
242 diagonal_args: impl Into<DiagonalArgs>,
243) -> TensorView<'_, T, B, D::SmallerOne>
244where
245 D: DimAPI + DimSmallerOneAPI,
246 D::SmallerOne: DimAPI,
247 R: DataAPI<Data = B::Raw>,
248 B: DeviceAPI<T>,
249{
250 diagonal_f(tensor, diagonal_args).unwrap()
251}
252
253impl<R, T, B, D> TensorAny<R, T, B, D>
254where
255 R: DataAPI<Data = B::Raw>,
256 B: DeviceAPI<T>,
257 D: DimAPI + DimSmallerOneAPI,
258 D::SmallerOne: DimAPI,
259{
260 pub fn into_diagonal_f(self, diagonal_args: impl Into<DiagonalArgs>) -> Result<TensorAny<R, T, B, D::SmallerOne>> {
261 into_diagonal_f(self, diagonal_args)
262 }
263
264 pub fn into_diagonal(self, diagonal_args: impl Into<DiagonalArgs>) -> TensorAny<R, T, B, D::SmallerOne> {
265 into_diagonal(self, diagonal_args)
266 }
267
268 pub fn diagonal_f(&self, diagonal_args: impl Into<DiagonalArgs>) -> Result<TensorView<'_, T, B, D::SmallerOne>> {
269 diagonal_f(self, diagonal_args)
270 }
271
272 pub fn diagonal(&self, diagonal_args: impl Into<DiagonalArgs>) -> TensorView<'_, T, B, D::SmallerOne> {
273 diagonal(self, diagonal_args)
274 }
275}
276
277pub fn into_diagonal_mut_f<S, D>(
282 tensor: TensorBase<S, D>,
283 diagonal_args: impl Into<DiagonalArgs>,
284) -> Result<TensorBase<S, D::SmallerOne>>
285where
286 D: DimAPI + DimSmallerOneAPI,
287 D::SmallerOne: DimAPI,
288{
289 let (data, layout) = tensor.into_raw_parts();
290 let DiagonalArgs { offset, axis1, axis2 } = diagonal_args.into();
291 let layout = layout.diagonal(offset, axis1, axis2)?;
292 return unsafe { Ok(TensorBase::new_unchecked(data, layout)) };
293}
294
295pub fn into_diagonal_mut<S, D>(
296 tensor: TensorBase<S, D>,
297 diagonal_args: impl Into<DiagonalArgs>,
298) -> TensorBase<S, D::SmallerOne>
299where
300 D: DimAPI + DimSmallerOneAPI,
301 D::SmallerOne: DimAPI,
302{
303 into_diagonal_mut_f(tensor, diagonal_args).unwrap()
304}
305
306pub fn diagonal_mut_f<R, T, B, D>(
307 tensor: &mut TensorAny<R, T, B, D>,
308 diagonal_args: impl Into<DiagonalArgs>,
309) -> Result<TensorMut<'_, T, B, D::SmallerOne>>
310where
311 D: DimAPI + DimSmallerOneAPI,
312 D::SmallerOne: DimAPI,
313 R: DataMutAPI<Data = B::Raw>,
314 B: DeviceAPI<T>,
315{
316 into_diagonal_mut_f(tensor.view_mut(), diagonal_args)
317}
318
319pub fn diagonal_mut<R, T, B, D>(
320 tensor: &mut TensorAny<R, T, B, D>,
321 diagonal_args: impl Into<DiagonalArgs>,
322) -> TensorMut<'_, T, B, D::SmallerOne>
323where
324 D: DimAPI + DimSmallerOneAPI,
325 D::SmallerOne: DimAPI,
326 R: DataMutAPI<Data = B::Raw>,
327 B: DeviceAPI<T>,
328{
329 diagonal_mut_f(tensor, diagonal_args).unwrap()
330}
331
332impl<R, T, B, D> TensorAny<R, T, B, D>
333where
334 R: DataMutAPI<Data = B::Raw>,
335 B: DeviceAPI<T>,
336 D: DimAPI + DimSmallerOneAPI,
337 D::SmallerOne: DimAPI,
338{
339 pub fn into_diagonal_mut_f(
340 self,
341 diagonal_args: impl Into<DiagonalArgs>,
342 ) -> Result<TensorAny<R, T, B, D::SmallerOne>> {
343 into_diagonal_mut_f(self, diagonal_args)
344 }
345
346 pub fn into_diagonal_mut(self, diagonal_args: impl Into<DiagonalArgs>) -> TensorAny<R, T, B, D::SmallerOne> {
347 into_diagonal_mut(self, diagonal_args)
348 }
349
350 pub fn diagonal_mut_f(
351 &mut self,
352 diagonal_args: impl Into<DiagonalArgs>,
353 ) -> Result<TensorMut<'_, T, B, D::SmallerOne>> {
354 diagonal_mut_f(self, diagonal_args)
355 }
356
357 pub fn diagonal_mut(&mut self, diagonal_args: impl Into<DiagonalArgs>) -> TensorMut<'_, T, B, D::SmallerOne> {
358 diagonal_mut(self, diagonal_args)
359 }
360}
361
362#[duplicate_item(
371 TensorStruct;
372 [Tensor<T, B, D>];
373 [TensorView<'_, T, B, D>];
374 [TensorViewMut<'_, T, B, D>];
375 [TensorCow<'_, T, B, D>];
376)]
377impl<T, D, B, I> Index<I> for TensorStruct
378where
379 T: Clone,
380 D: DimAPI,
381 B: DeviceAPI<T, Raw = Vec<T>>,
382 I: AsRef<[usize]>,
383{
384 type Output = T;
385
386 #[inline]
387 fn index(&self, index: I) -> &Self::Output {
388 let index = index.as_ref().iter().map(|&v| v as isize).collect::<Vec<_>>();
389 let i = self.layout().index(index.as_ref());
390 let raw = self.raw();
391 raw.index(i)
392 }
393}
394
395#[duplicate_item(
396 TensorStruct;
397 [Tensor<T, B, D>];
398 [TensorViewMut<'_, T, B, D>];
399)]
400impl<T, D, B, I> IndexMut<I> for TensorStruct
401where
402 T: Clone,
403 D: DimAPI,
404 B: DeviceAPI<T, Raw = Vec<T>>,
405 I: AsRef<[usize]>,
406{
407 #[inline]
408 fn index_mut(&mut self, index: I) -> &mut Self::Output {
409 let index = index.as_ref().iter().map(|&v| v as isize).collect::<Vec<_>>();
410 let i = self.layout().index(index.as_ref());
411 let raw = self.raw_mut();
412 raw.index_mut(i)
413 }
414}
415
416#[duplicate_item(
421 TensorStruct;
422 [Tensor<T, B, D>];
423 [TensorView<'_, T, B, D>];
424 [TensorViewMut<'_, T, B, D>];
425 [TensorCow<'_, T, B, D>];
426)]
427impl<T, B, D> TensorStruct
428where
429 T: Clone,
430 D: DimAPI,
431 B: DeviceAPI<T, Raw = Vec<T>>,
432{
433 #[inline]
438 pub unsafe fn index_uncheck<I>(&self, index: I) -> &T
439 where
440 I: AsRef<[usize]>,
441 {
442 let index = index.as_ref();
443 let i = unsafe { self.layout().index_uncheck(index) } as usize;
444 let raw = self.raw();
445 raw.index(i)
446 }
447}
448
449#[duplicate_item(
450 TensorStruct;
451 [Tensor<T, B, D>];
452 [TensorViewMut<'_, T, B, D>];
453)]
454impl<T, B, D> TensorStruct
455where
456 T: Clone,
457 D: DimAPI,
458 B: DeviceAPI<T, Raw = Vec<T>>,
459{
460 #[inline]
465 pub unsafe fn index_mut_uncheck<I>(&mut self, index: I) -> &mut T
466 where
467 I: AsRef<[usize]>,
468 {
469 let index = index.as_ref();
470 let i = unsafe { self.layout().index_uncheck(index) } as usize;
471 let raw = self.raw_mut();
472 raw.index_mut(i)
473 }
474}
475
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_tensor_slice_1d() {
        let tensor = asarray(vec![1, 2, 3, 4, 5]);
        let tensor_slice = tensor.slice(s![1..4]);
        println!("{tensor_slice:?}");
        let tensor_slice = tensor.slice(s![1..4, None]);
        println!("{tensor_slice:?}");
        let tensor_slice = tensor.slice(1);
        println!("{tensor_slice:?}");
        let tensor_slice = tensor.slice(slice!(2, 7, 2));
        println!("{tensor_slice:?}");

        // Updates through a mutable slice must reach the parent tensor and touch only the
        // sliced elements (indices 1, 2, 3).
        let mut tensor = asarray(vec![1, 2, 3, 4, 5]);
        let mut tensor_slice = tensor.slice_mut(s![1..4]);
        tensor_slice += 10;
        println!("{tensor:?}");
        assert_eq!([tensor[[0]], tensor[[1]], tensor[[2]], tensor[[3]], tensor[[4]]], [1, 12, 13, 14, 5]);
        *&mut tensor.slice_mut(s![1..4]) += 10;
        println!("{tensor:?}");
        assert_eq!([tensor[[0]], tensor[[1]], tensor[[2]], tensor[[3]], tensor[[4]]], [1, 22, 23, 24, 5]);
    }

    #[test]
    fn test_tensor_nd() {
        let tensor = arange(24.0).into_shape([2, 3, 4]);
        let tensor_slice = tensor.slice(s![1..2, 1..3, 1..4]);
        println!("{tensor_slice:?}");
        let tensor_slice = tensor.slice(s![1]);
        println!("{tensor_slice:?}");
    }

    #[test]
    fn test_tensor_index() {
        let mut tensor = asarray(vec![1, 2, 3, 4, 5]);
        let value = tensor[[1]];
        println!("{value:?}");
        assert_eq!(value, 2);
        let tensor_view = tensor.view();
        let value = tensor_view[[2]];
        {
            let tensor_view = tensor.view();
            let value = tensor_view[[3]];
            println!("{value:?}");
            assert_eq!(value, 4);
            let mut tensor_view_mut = tensor.view_mut();
            tensor_view_mut[[4]] += 1;
            *&mut tensor_view_mut.slice_mut(4) += 1;
        }
        println!("{value:?}");
        assert_eq!(value, 3);
        println!("{tensor:?}");
        // Both increments (via `IndexMut` and via a mutable slice selecting index 4) hit
        // element 4: 5 + 1 + 1.
        assert_eq!(tensor[[4]], 7);
    }

    #[test]
    fn test_diagonal_compiles() {
        let a = arange(24.0).into_shape([2, 3, 4]);
        let b = a.diagonal(1);
        println!("{b:?}");
    }
}
537}