rstsr_core/tensor/
creation.rs

1//! Creation methods for `Tensor` struct.
2//!
3//! This module relates to the [Python array API standard v2023.12](https://data-apis.org/array-api/2023.12/API_specification/creation_functions.html).
4//!
5//! Todo list:
6//! - [x] `arange`: [`arange`]
7//! - [x] `asarray`: [`asarray`] (defined elsewhere)
8//! - [x] `empty`: [`empty`]
9//! - [x] `empty_like`: [`empty_like`]
10//! - [x] `eye`: [`eye`]
11//! - [ ] ~`from_dlpack`~
12//! - [x] `full`: [`full`]
13//! - [x] `full_like`: [`full_like`]
14//! - [x] `linspace`: [`linspace`]
15//! - [x] `meshgrid`
16//! - [x] `ones`: [`ones`]
17//! - [x] `ones_like`: [`ones_like`]
18//! - [x] `tril`
19//! - [x] `triu`
20//! - [x] `zeros`: [`zeros`]
21//! - [x] `zeros_like`: [`zeros_like`]
22
23use crate::prelude_dev::*;
24use num::complex::ComplexFloat;
25use num::Num;
26
27/* #region arange */
28
29pub trait ArangeAPI<Inp> {
30    type Out;
31
32    fn arange_f(self) -> Result<Self::Out>;
33
34    fn arange(self) -> Self::Out
35    where
36        Self: Sized,
37    {
38        Self::arange_f(self).unwrap()
39    }
40}
41
42/// Evenly spaced values within the half-open interval `[start, stop)` as
43/// one-dimensional array.
44///
45/// # See also
46///
47/// - [Python array API standard: `arange`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.arange.html)
48pub fn arange<Args, Inp>(param: Args) -> Args::Out
49where
50    Args: ArangeAPI<Inp>,
51{
52    return ArangeAPI::arange(param);
53}
54
55pub fn arange_f<Args, Inp>(param: Args) -> Result<Args::Out>
56where
57    Args: ArangeAPI<Inp>,
58{
59    return ArangeAPI::arange_f(param);
60}
61
62impl<T, B> ArangeAPI<(T, B)> for (T, T, T, &B)
63where
64    T: Num + PartialOrd,
65    B: DeviceAPI<T> + DeviceCreationPartialOrdNumAPI<T>,
66{
67    type Out = Tensor<T, B, IxD>;
68
69    fn arange_f(self) -> Result<Self::Out> {
70        // full implementation
71        let (start, stop, step, device) = self;
72        let data = device.arange_impl(start, stop, step)?;
73        let layout = vec![data.len()].into();
74        unsafe { Ok(Tensor::new_unchecked(data, layout)) }
75    }
76}
77
78impl<T, B> ArangeAPI<(T, B)> for (T, T, &B)
79where
80    T: Num + PartialOrd,
81    B: DeviceAPI<T> + DeviceCreationPartialOrdNumAPI<T>,
82{
83    type Out = Tensor<T, B, IxD>;
84
85    fn arange_f(self) -> Result<Self::Out> {
86        // (start, stop, device) -> (start, stop, 1, device)
87        let (start, stop, device) = self;
88        let step = T::one();
89        arange_f((start, stop, step, device))
90    }
91}
92
93impl<T, B> ArangeAPI<(T, B)> for (T, &B)
94where
95    T: Num + PartialOrd,
96    B: DeviceAPI<T> + DeviceCreationPartialOrdNumAPI<T>,
97{
98    type Out = Tensor<T, B, IxD>;
99
100    fn arange_f(self) -> Result<Self::Out> {
101        // (stop, device) -> (0, stop, 1, device)
102        let (stop, device) = self;
103        let start = T::zero();
104        let step = T::one();
105        arange_f((start, stop, step, device))
106    }
107}
108
109impl<T> ArangeAPI<T> for (T, T, T)
110where
111    T: Num + PartialOrd + Clone + Send + Sync,
112{
113    type Out = Tensor<T, DeviceCpu, IxD>;
114
115    fn arange_f(self) -> Result<Self::Out> {
116        // full implementation
117        let (start, stop, step) = self;
118        arange_f((start, stop, step, &DeviceCpu::default()))
119    }
120}
121
122impl<T> ArangeAPI<T> for (T, T)
123where
124    T: Num + PartialOrd + Clone,
125{
126    type Out = Tensor<T, DeviceCpu, IxD>;
127
128    fn arange_f(self) -> Result<Self::Out> {
129        // (start, stop) -> (start, stop, 1)
130        let (start, stop) = self;
131        arange_f((start, stop, &DeviceCpu::default()))
132    }
133}
134
135impl<T> ArangeAPI<T> for T
136where
137    T: Num + PartialOrd + Clone,
138{
139    type Out = Tensor<T, DeviceCpu, IxD>;
140
141    fn arange_f(self) -> Result<Self::Out> {
142        // (stop) -> (0, stop, 1)
143        arange_f((T::zero(), self, &DeviceCpu::default()))
144    }
145}
146
147/* #endregion */
148
149/* #region empty */
150
151pub trait EmptyAPI<Inp> {
152    type Out;
153
154    /// # Safety
155    ///
156    /// This function is unsafe because it creates a tensor with uninitialized.
157    unsafe fn empty_f(self) -> Result<Self::Out>;
158
159    /// # Safety
160    ///
161    /// This function is unsafe because it creates a tensor with uninitialized.
162    unsafe fn empty(self) -> Self::Out
163    where
164        Self: Sized,
165    {
166        Self::empty_f(self).unwrap()
167    }
168}
169
170/// Uninitialized tensor having a specified shape.
171///
172/// # Safety
173///
174/// This function is unsafe because it creates a tensor with uninitialized.
175///
176/// # See also
177///
178/// - [Python array API standard: `empty`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.empty.html)
179pub unsafe fn empty<Args, Inp>(param: Args) -> Args::Out
180where
181    Args: EmptyAPI<Inp>,
182{
183    return EmptyAPI::empty(param);
184}
185
186/// # Safety
187///
188/// This function is unsafe because it creates a tensor with uninitialized.
189pub unsafe fn empty_f<Args, Inp>(param: Args) -> Result<Args::Out>
190where
191    Args: EmptyAPI<Inp>,
192{
193    return EmptyAPI::empty_f(param);
194}
195
196impl<T, D, B> EmptyAPI<(T, D)> for (Layout<D>, &B)
197where
198    D: DimAPI,
199    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
200{
201    type Out = Tensor<T, B, IxD>;
202
203    unsafe fn empty_f(self) -> Result<Self::Out> {
204        let (layout, device) = self;
205        let (_, idx_max) = layout.bounds_index()?;
206        let storage = B::empty_impl(device, idx_max)?;
207        unsafe { Ok(Tensor::new_unchecked(storage, layout.into_dim()?)) }
208    }
209}
210
211impl<T, D, B> EmptyAPI<(T, D)> for (D, FlagOrder, &B)
212where
213    D: DimAPI,
214    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
215{
216    type Out = Tensor<T, B, IxD>;
217
218    unsafe fn empty_f(self) -> Result<Self::Out> {
219        let (shape, order, device) = self;
220        let layout = shape.new_contig(None, order);
221        empty_f((layout, device))
222    }
223}
224
225impl<T, D, B> EmptyAPI<(T, D)> for (D, &B)
226where
227    D: DimAPI,
228    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
229{
230    type Out = Tensor<T, B, IxD>;
231
232    unsafe fn empty_f(self) -> Result<Self::Out> {
233        let (shape, device) = self;
234        let default_order = device.default_order();
235        let layout = shape.new_contig(None, default_order);
236        empty_f((layout, device))
237    }
238}
239
240impl<T, D> EmptyAPI<(T, D)> for (D, FlagOrder)
241where
242    D: DimAPI,
243{
244    type Out = Tensor<T, DeviceCpu, IxD>;
245
246    unsafe fn empty_f(self) -> Result<Self::Out> {
247        let (shape, order) = self;
248        empty_f((shape, order, &DeviceCpu::default()))
249    }
250}
251
252#[duplicate_item(L; [D]; [Layout<D>])]
253impl<T, D> EmptyAPI<(T, D)> for L
254where
255    D: DimAPI,
256{
257    type Out = Tensor<T, DeviceCpu, IxD>;
258
259    unsafe fn empty_f(self) -> Result<Self::Out> {
260        empty_f((self, &DeviceCpu::default()))
261    }
262}
263
264/* #endregion */
265
266/* #region empty_like */
267
268pub trait EmptyLikeAPI<Inp> {
269    type Out;
270
271    /// # Safety
272    ///
273    /// This function is unsafe because it creates a tensor with uninitialized.
274    unsafe fn empty_like_f(self) -> Result<Self::Out>;
275
276    /// # Safety
277    ///
278    /// This function is unsafe because it creates a tensor with uninitialized.
279    unsafe fn empty_like(self) -> Self::Out
280    where
281        Self: Sized,
282    {
283        Self::empty_like_f(self).unwrap()
284    }
285}
286
287/// Uninitialized tensor with the same shape as an input tensor.
288///
289/// # Safety
290///
291/// This function is unsafe because it creates a tensor with uninitialized.
292///
293/// # See also
294///
295/// - [Python array API standard: `empty_like`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.empty_like.html)
296pub unsafe fn empty_like<Args, Inp>(param: Args) -> Args::Out
297where
298    Args: EmptyLikeAPI<Inp>,
299{
300    return EmptyLikeAPI::empty_like(param);
301}
302
303/// # Safety
304///
305/// This function is unsafe because it creates a tensor with uninitialized.
306pub unsafe fn empty_like_f<Args, Inp>(param: Args) -> Result<Args::Out>
307where
308    Args: EmptyLikeAPI<Inp>,
309{
310    return EmptyLikeAPI::empty_like_f(param);
311}
312
313impl<R, T, B, D> EmptyLikeAPI<()> for (&TensorAny<R, T, B, D>, TensorIterOrder, &B)
314where
315    R: DataAPI<Data = B::Raw>,
316    D: DimAPI,
317    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
318{
319    type Out = Tensor<T, B, D>;
320
321    unsafe fn empty_like_f(self) -> Result<Self::Out> {
322        let (tensor, order, device) = self;
323        let layout = layout_for_array_copy(tensor.layout(), order)?;
324        let idx_max = layout.size();
325        let storage = device.empty_impl(idx_max)?;
326        unsafe { Ok(Tensor::new_unchecked(storage, layout)) }
327    }
328}
329
330impl<R, T, B, D> EmptyLikeAPI<()> for (&TensorAny<R, T, B, D>, &B)
331where
332    R: DataAPI<Data = B::Raw>,
333    D: DimAPI,
334    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
335{
336    type Out = Tensor<T, B, D>;
337
338    unsafe fn empty_like_f(self) -> Result<Self::Out> {
339        let (tensor, device) = self;
340        empty_like_f((tensor, TensorIterOrder::default(), device))
341    }
342}
343
344impl<R, T, B, D> EmptyLikeAPI<()> for (&TensorAny<R, T, B, D>, TensorIterOrder)
345where
346    R: DataAPI<Data = B::Raw>,
347    D: DimAPI,
348    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
349{
350    type Out = Tensor<T, B, D>;
351
352    unsafe fn empty_like_f(self) -> Result<Self::Out> {
353        let (tensor, order) = self;
354        let device = tensor.device();
355        empty_like_f((tensor, order, device))
356    }
357}
358
359impl<R, T, B, D> EmptyLikeAPI<()> for &TensorAny<R, T, B, D>
360where
361    R: DataAPI<Data = B::Raw>,
362    D: DimAPI,
363    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
364{
365    type Out = Tensor<T, B, D>;
366
367    unsafe fn empty_like_f(self) -> Result<Self::Out> {
368        let device = self.device();
369        empty_like_f((self, TensorIterOrder::default(), device))
370    }
371}
372
373/* #endregion */
374
375/* #region eye */
376
377pub trait EyeAPI<Inp> {
378    type Out;
379
380    fn eye_f(self) -> Result<Self::Out>;
381
382    fn eye(self) -> Self::Out
383    where
384        Self: Sized,
385    {
386        Self::eye_f(self).unwrap()
387    }
388}
389
390/// Returns a two-dimensional array with ones on the kth diagonal and zeros
391/// elsewhere.
392///
393/// # See also
394///
395/// - [Python array API standard: `eye`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.eye.html)
396pub fn eye<Args, Inp>(param: Args) -> Args::Out
397where
398    Args: EyeAPI<Inp>,
399{
400    return EyeAPI::eye(param);
401}
402
403pub fn eye_f<Args, Inp>(param: Args) -> Result<Args::Out>
404where
405    Args: EyeAPI<Inp>,
406{
407    return EyeAPI::eye_f(param);
408}
409
410impl<T, B> EyeAPI<(T, B)> for (usize, usize, isize, FlagOrder, &B)
411where
412    T: Num,
413    B: DeviceAPI<T> + DeviceCreationNumAPI<T> + OpAssignAPI<T, Ix1>,
414{
415    type Out = Tensor<T, B, IxD>;
416
417    fn eye_f(self) -> Result<Self::Out> {
418        let (n_rows, n_cols, k, order, device) = self;
419        let layout = match order {
420            RowMajor => [n_rows, n_cols].c(),
421            ColMajor => [n_cols, n_rows].f(),
422        };
423        let mut storage = device.zeros_impl(layout.size())?;
424        let layout_diag = layout.diagonal(Some(k), Some(0), Some(1))?;
425        device.fill(storage.raw_mut(), &layout_diag, T::one())?;
426        unsafe { Ok(Tensor::new_unchecked(storage, layout.into_dim()?)) }
427    }
428}
429
430impl<T, B> EyeAPI<(T, B)> for (usize, usize, isize, &B)
431where
432    T: Num,
433    B: DeviceAPI<T> + DeviceCreationNumAPI<T> + OpAssignAPI<T, Ix1>,
434{
435    type Out = Tensor<T, B, IxD>;
436
437    fn eye_f(self) -> Result<Self::Out> {
438        // (n_rows, n_cols, k, device) -> (n_rows, n_cols, k, C, device)
439        let (n_rows, n_cols, k, device) = self;
440        let default_order = device.default_order();
441        eye_f((n_rows, n_cols, k, default_order, device))
442    }
443}
444
445impl<T, B> EyeAPI<(T, B)> for (usize, &B)
446where
447    T: Num,
448    B: DeviceAPI<T> + DeviceCreationNumAPI<T> + OpAssignAPI<T, Ix1>,
449{
450    type Out = Tensor<T, B, IxD>;
451
452    fn eye_f(self) -> Result<Self::Out> {
453        // (n_rows, n_cols, k, device) -> (n_rows, n_cols, k, C, device)
454        let (n_rows, device) = self;
455        let default_order = device.default_order();
456        eye_f((n_rows, n_rows, 0, default_order, device))
457    }
458}
459
460impl<T> EyeAPI<T> for (usize, usize, isize, FlagOrder)
461where
462    T: Num + Clone + Send + Sync,
463{
464    type Out = Tensor<T, DeviceCpu, IxD>;
465
466    fn eye_f(self) -> Result<Self::Out> {
467        let (n_rows, n_cols, k, order) = self;
468        eye_f((n_rows, n_cols, k, order, &DeviceCpu::default()))
469    }
470}
471
472impl<T> EyeAPI<T> for (usize, usize, isize)
473where
474    T: Num + Clone + Send + Sync,
475{
476    type Out = Tensor<T, DeviceCpu, IxD>;
477
478    fn eye_f(self) -> Result<Self::Out> {
479        // (n_rows, n_cols, k) -> (n_rows, n_cols, k, C)
480        let (n_rows, n_cols, k) = self;
481        let device = DeviceCpu::default();
482        let default_order = device.default_order();
483        eye_f((n_rows, n_cols, k, default_order, &device))
484    }
485}
486
487impl<T> EyeAPI<T> for usize
488where
489    T: Num + Clone + Send + Sync,
490{
491    type Out = Tensor<T, DeviceCpu, IxD>;
492
493    fn eye_f(self) -> Result<Self::Out> {
494        // n_rows -> (n_rows, n_rows, 0, C)
495        let device = DeviceCpu::default();
496        let default_order = device.default_order();
497        eye_f((self, self, 0, default_order, &device))
498    }
499}
500
501/* #endregion */
502
503/* #region full */
504
505pub trait FullAPI<Inp> {
506    type Out;
507
508    fn full_f(self) -> Result<Self::Out>;
509
510    fn full(self) -> Self::Out
511    where
512        Self: Sized,
513    {
514        Self::full_f(self).unwrap()
515    }
516}
517
518/// New tensor having a specified shape and filled with given value.
519///
520/// # See also
521///
522/// - [Python array API standard: `full`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.full.html)
523pub fn full<Args, Inp>(param: Args) -> Args::Out
524where
525    Args: FullAPI<Inp>,
526{
527    return FullAPI::full(param);
528}
529
530pub fn full_f<Args, Inp>(param: Args) -> Result<Args::Out>
531where
532    Args: FullAPI<Inp>,
533{
534    return FullAPI::full_f(param);
535}
536
537impl<T, D, B> FullAPI<(T, D)> for (Layout<D>, T, &B)
538where
539    T: Clone,
540    D: DimAPI,
541    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
542{
543    type Out = Tensor<T, B, IxD>;
544
545    fn full_f(self) -> Result<Self::Out> {
546        let (layout, fill, device) = self;
547        let idx_max = layout.size();
548        let storage = device.full_impl(idx_max, fill)?;
549        unsafe { Ok(Tensor::new_unchecked(storage, layout.into_dim()?)) }
550    }
551}
552
553impl<T, D, B> FullAPI<(T, D)> for (D, T, FlagOrder, &B)
554where
555    T: Clone,
556    D: DimAPI,
557    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
558{
559    type Out = Tensor<T, B, IxD>;
560
561    fn full_f(self) -> Result<Self::Out> {
562        let (shape, fill, order, device) = self;
563        let layout = shape.new_contig(None, order);
564        full_f((layout, fill, device))
565    }
566}
567
568impl<T, D, B> FullAPI<(T, D)> for (D, T, &B)
569where
570    T: Clone,
571    D: DimAPI,
572    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
573{
574    type Out = Tensor<T, B, IxD>;
575
576    fn full_f(self) -> Result<Self::Out> {
577        let (shape, fill, device) = self;
578        let default_order = device.default_order();
579        let layout = shape.new_contig(None, default_order);
580        full_f((layout, fill, device))
581    }
582}
583
584impl<T, D> FullAPI<(T, D)> for (D, T, FlagOrder)
585where
586    T: Clone,
587    D: DimAPI,
588{
589    type Out = Tensor<T, DeviceCpu, IxD>;
590
591    fn full_f(self) -> Result<Self::Out> {
592        let (shape, fill, order) = self;
593        full_f((shape, fill, order, &DeviceCpu::default()))
594    }
595}
596
597#[duplicate_item(L; [D]; [Layout<D>])]
598impl<T, D> FullAPI<(T, D)> for (L, T)
599where
600    T: Clone,
601    D: DimAPI,
602{
603    type Out = Tensor<T, DeviceCpu, IxD>;
604
605    fn full_f(self) -> Result<Self::Out> {
606        let (shape, fill) = self;
607        full_f((shape, fill, &DeviceCpu::default()))
608    }
609}
610
611/* #endregion */
612
613/* #region full_like */
614
615pub trait FullLikeAPI<Inp> {
616    type Out;
617
618    fn full_like_f(self) -> Result<Self::Out>;
619
620    fn full_like(self) -> Self::Out
621    where
622        Self: Sized,
623    {
624        Self::full_like_f(self).unwrap()
625    }
626}
627
628/// New tensor filled with given value and having the same shape as an input
629/// tensor.
630///
631/// # See also
632///
633/// - [Python array API standard: `full_like`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.full_like.html)
634pub fn full_like<Args, Inp>(param: Args) -> Args::Out
635where
636    Args: FullLikeAPI<Inp>,
637{
638    return FullLikeAPI::full_like(param);
639}
640
641pub fn full_like_f<Args, Inp>(param: Args) -> Result<Args::Out>
642where
643    Args: FullLikeAPI<Inp>,
644{
645    return FullLikeAPI::full_like_f(param);
646}
647
648impl<R, T, B, D> FullLikeAPI<()> for (&TensorAny<R, T, B, D>, T, TensorIterOrder, &B)
649where
650    T: Clone,
651    R: DataAPI<Data = B::Raw>,
652    D: DimAPI,
653    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
654{
655    type Out = Tensor<T, B, D>;
656
657    fn full_like_f(self) -> Result<Self::Out> {
658        let (tensor, fill, order, device) = self;
659        let layout = layout_for_array_copy(tensor.layout(), order)?;
660        let idx_max = layout.size();
661        let storage = device.full_impl(idx_max, fill)?;
662        unsafe { Ok(Tensor::new_unchecked(storage, layout)) }
663    }
664}
665
666impl<R, T, B, D> FullLikeAPI<()> for (&TensorAny<R, T, B, D>, T, &B)
667where
668    T: Clone,
669    R: DataAPI<Data = B::Raw>,
670    D: DimAPI,
671    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
672{
673    type Out = Tensor<T, B, D>;
674
675    fn full_like_f(self) -> Result<Self::Out> {
676        let (tensor, fill, device) = self;
677        full_like_f((tensor, fill, TensorIterOrder::default(), device))
678    }
679}
680
681impl<R, T, B, D> FullLikeAPI<()> for (&TensorAny<R, T, B, D>, T, TensorIterOrder)
682where
683    T: Clone,
684    R: DataAPI<Data = B::Raw>,
685    D: DimAPI,
686    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
687{
688    type Out = Tensor<T, B, D>;
689
690    fn full_like_f(self) -> Result<Self::Out> {
691        let (tensor, fill, order) = self;
692        let device = tensor.device();
693        full_like_f((tensor, fill, order, device))
694    }
695}
696
697impl<R, T, B, D> FullLikeAPI<()> for (&TensorAny<R, T, B, D>, T)
698where
699    T: Clone,
700    R: DataAPI<Data = B::Raw>,
701    D: DimAPI,
702    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
703{
704    type Out = Tensor<T, B, D>;
705
706    fn full_like_f(self) -> Result<Self::Out> {
707        let (tensor, fill) = self;
708        let device = tensor.device();
709        full_like_f((tensor, fill, TensorIterOrder::default(), device))
710    }
711}
712
713impl<R, T, B, D> TensorAny<R, T, B, D>
714where
715    R: DataAPI<Data = B::Raw>,
716    T: Clone,
717    D: DimAPI,
718    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
719{
720    pub fn full_like(&self, fill: T) -> Tensor<T, B, D> {
721        full_like((self, fill))
722    }
723
724    pub fn full_like_f(&self, fill: T) -> Result<Tensor<T, B, D>> {
725        full_like_f((self, fill))
726    }
727}
728
729/* #endregion */
730
731/* #region linspace */
732
733pub trait LinspaceAPI<Inp> {
734    type Out;
735
736    fn linspace_f(self) -> Result<Self::Out>;
737
738    fn linspace(self) -> Self::Out
739    where
740        Self: Sized,
741    {
742        Self::linspace_f(self).unwrap()
743    }
744}
745
746/// Evenly spaced numbers over a specified interval.
747///
748/// For boundary condition, current implementation is similar to numpy,
749/// where `n = 0` will return an empty array, and `n = 1` will return an
750/// array with starting value.
751///
752/// # See also
753///
754/// - [Python array API standard: `linspace`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.linspace.html)
755pub fn linspace<Args, Inp>(param: Args) -> Args::Out
756where
757    Args: LinspaceAPI<Inp>,
758{
759    return LinspaceAPI::linspace(param);
760}
761
762pub fn linspace_f<Args, Inp>(param: Args) -> Result<Args::Out>
763where
764    Args: LinspaceAPI<Inp>,
765{
766    return LinspaceAPI::linspace_f(param);
767}
768
769impl<T, B> LinspaceAPI<(T, B)> for (T, T, usize, bool, &B)
770where
771    T: ComplexFloat,
772    B: DeviceAPI<T> + DeviceCreationComplexFloatAPI<T>,
773{
774    type Out = Tensor<T, B, IxD>;
775
776    fn linspace_f(self) -> Result<Self::Out> {
777        let (start, end, n, endpoint, device) = self;
778        let data = B::linspace_impl(device, start, end, n, endpoint)?;
779        let layout = vec![data.len()].into();
780        unsafe { Ok(Tensor::new_unchecked(data, layout)) }
781    }
782}
783
784impl<T, B> LinspaceAPI<(T, B)> for (T, T, usize, &B)
785where
786    T: ComplexFloat,
787    B: DeviceAPI<T> + DeviceCreationComplexFloatAPI<T>,
788{
789    type Out = Tensor<T, B, IxD>;
790
791    fn linspace_f(self) -> Result<Self::Out> {
792        // (start, end, n, device) -> (start, end, n, true, device)
793        let (start, end, n, device) = self;
794        linspace_f((start, end, n, true, device))
795    }
796}
797
798impl<T> LinspaceAPI<T> for (T, T, usize, bool)
799where
800    T: ComplexFloat + Send + Sync,
801{
802    type Out = Tensor<T, DeviceCpu, IxD>;
803
804    fn linspace_f(self) -> Result<Self::Out> {
805        // (start, end, n, endpoint) -> (start, end, n, endpoint, device)
806        let (start, end, n, endpoint) = self;
807        linspace_f((start, end, n, endpoint, &DeviceCpu::default()))
808    }
809}
810
811impl<T> LinspaceAPI<T> for (T, T, usize)
812where
813    T: ComplexFloat + Send + Sync,
814{
815    type Out = Tensor<T, DeviceCpu, IxD>;
816
817    fn linspace_f(self) -> Result<Self::Out> {
818        // (start, end, n) -> (start, end, n, true, device)
819        let (start, end, n) = self;
820        linspace_f((start, end, n, true, &DeviceCpu::default()))
821    }
822}
823
824/* #endregion */
825
826/* #region ones */
827
828pub trait OnesAPI<Inp> {
829    type Out;
830
831    fn ones_f(self) -> Result<Self::Out>;
832
833    fn ones(self) -> Self::Out
834    where
835        Self: Sized,
836    {
837        Self::ones_f(self).unwrap()
838    }
839}
840
841/// New tensor filled with ones and having a specified shape.
842///
843/// # See also
844///
845/// - [Python array API standard: `ones`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.ones.html)
846pub fn ones<Args, Inp>(param: Args) -> Args::Out
847where
848    Args: OnesAPI<Inp>,
849{
850    return OnesAPI::ones(param);
851}
852
853pub fn ones_f<Args, Inp>(param: Args) -> Result<Args::Out>
854where
855    Args: OnesAPI<Inp>,
856{
857    return OnesAPI::ones_f(param);
858}
859
860impl<T, D, B> OnesAPI<(T, D)> for (Layout<D>, &B)
861where
862    T: Num,
863    D: DimAPI,
864    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
865{
866    type Out = Tensor<T, B, IxD>;
867
868    fn ones_f(self) -> Result<Self::Out> {
869        let (layout, device) = self;
870        let (_, idx_max) = layout.bounds_index()?;
871        let storage = device.ones_impl(idx_max)?;
872        unsafe { Ok(Tensor::new_unchecked(storage, layout.into_dim()?)) }
873    }
874}
875
876impl<T, D, B> OnesAPI<(T, D)> for (D, FlagOrder, &B)
877where
878    T: Num,
879    D: DimAPI,
880    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
881{
882    type Out = Tensor<T, B, IxD>;
883
884    fn ones_f(self) -> Result<Self::Out> {
885        let (shape, order, device) = self;
886        let layout = shape.new_contig(None, order);
887        ones_f((layout, device))
888    }
889}
890
891impl<T, D, B> OnesAPI<(T, D)> for (D, &B)
892where
893    T: Num,
894    D: DimAPI,
895    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
896{
897    type Out = Tensor<T, B, IxD>;
898
899    fn ones_f(self) -> Result<Self::Out> {
900        let (shape, device) = self;
901        let default_order = device.default_order();
902        let layout = shape.new_contig(None, default_order);
903        ones_f((layout, device))
904    }
905}
906
907impl<T, D> OnesAPI<(T, D)> for (D, FlagOrder)
908where
909    T: Num + Clone,
910    D: DimAPI,
911{
912    type Out = Tensor<T, DeviceCpu, IxD>;
913
914    fn ones_f(self) -> Result<Self::Out> {
915        let (shape, order) = self;
916        ones_f((shape, order, &DeviceCpu::default()))
917    }
918}
919
920#[duplicate_item(L; [D]; [Layout<D>])]
921impl<T, D> OnesAPI<(T, D)> for L
922where
923    T: Num + Clone,
924    D: DimAPI,
925{
926    type Out = Tensor<T, DeviceCpu, IxD>;
927
928    fn ones_f(self) -> Result<Self::Out> {
929        ones_f((self, &DeviceCpu::default()))
930    }
931}
932
933/* #endregion */
934
935/* #region ones_like */
936
937pub trait OnesLikeAPI<Inp> {
938    type Out;
939
940    fn ones_like_f(self) -> Result<Self::Out>;
941
942    fn ones_like(self) -> Self::Out
943    where
944        Self: Sized,
945    {
946        Self::ones_like_f(self).unwrap()
947    }
948}
949
950/// New tensor filled with ones and having the same shape as an input
951/// tensor.
952///
953/// # See also
954///
955/// - [Python array API standard: `ones_like`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.ones_like.html)
956pub fn ones_like<Args, Inp>(param: Args) -> Args::Out
957where
958    Args: OnesLikeAPI<Inp>,
959{
960    return OnesLikeAPI::ones_like(param);
961}
962
963pub fn ones_like_f<Args, Inp>(param: Args) -> Result<Args::Out>
964where
965    Args: OnesLikeAPI<Inp>,
966{
967    return OnesLikeAPI::ones_like_f(param);
968}
969
970impl<R, T, B, D> OnesLikeAPI<()> for (&TensorAny<R, T, B, D>, TensorIterOrder, &B)
971where
972    R: DataAPI<Data = B::Raw>,
973    T: Num,
974    D: DimAPI,
975    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
976{
977    type Out = Tensor<T, B, D>;
978
979    fn ones_like_f(self) -> Result<Self::Out> {
980        let (tensor, order, device) = self;
981        let layout = layout_for_array_copy(tensor.layout(), order)?;
982        let idx_max = layout.size();
983        let storage = device.ones_impl(idx_max)?;
984        unsafe { Ok(Tensor::new_unchecked(storage, layout)) }
985    }
986}
987
988impl<R, T, B, D> OnesLikeAPI<()> for (&TensorAny<R, T, B, D>, &B)
989where
990    R: DataAPI<Data = B::Raw>,
991    T: Num,
992    D: DimAPI,
993    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
994{
995    type Out = Tensor<T, B, D>;
996
997    fn ones_like_f(self) -> Result<Self::Out> {
998        let (tensor, device) = self;
999        ones_like_f((tensor, TensorIterOrder::default(), device))
1000    }
1001}
1002
1003impl<R, T, B, D> OnesLikeAPI<()> for (&TensorAny<R, T, B, D>, TensorIterOrder)
1004where
1005    R: DataAPI<Data = B::Raw>,
1006    T: Num,
1007    D: DimAPI,
1008    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1009{
1010    type Out = Tensor<T, B, D>;
1011
1012    fn ones_like_f(self) -> Result<Self::Out> {
1013        let (tensor, order) = self;
1014        let device = tensor.device();
1015        ones_like_f((tensor, order, device))
1016    }
1017}
1018
1019impl<R, T, B, D> OnesLikeAPI<()> for &TensorAny<R, T, B, D>
1020where
1021    R: DataAPI<Data = B::Raw>,
1022    T: Num,
1023    D: DimAPI,
1024    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1025{
1026    type Out = Tensor<T, B, D>;
1027
1028    fn ones_like_f(self) -> Result<Self::Out> {
1029        let device = self.device();
1030        ones_like_f((self, TensorIterOrder::default(), device))
1031    }
1032}
1033
1034impl<R, T, B, D> TensorAny<R, T, B, D>
1035where
1036    R: DataAPI<Data = B::Raw>,
1037    D: DimAPI,
1038    T: Num,
1039    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1040{
1041    /// New tensor filled with ones and having the same shape as an input
1042    /// tensor.
1043    ///
1044    /// # See also
1045    ///
1046    /// [`ones_like`]
1047    pub fn ones_like(&self) -> Tensor<T, B, D> {
1048        ones_like((self, TensorIterOrder::default(), self.device()))
1049    }
1050
1051    pub fn ones_like_f(&self) -> Result<Tensor<T, B, D>> {
1052        ones_like_f((self, TensorIterOrder::default(), self.device()))
1053    }
1054}
1055
1056/* #endregion */
1057
1058/* #region zeros */
1059
1060pub trait ZerosAPI<Inp> {
1061    type Out;
1062
1063    fn zeros_f(self) -> Result<Self::Out>;
1064
1065    fn zeros(self) -> Self::Out
1066    where
1067        Self: Sized,
1068    {
1069        Self::zeros_f(self).unwrap()
1070    }
1071}
1072
1073/// New tensor filled with zeros and having a specified shape.
1074///
1075/// # See also
1076///
1077/// - [Python array API standard: `zeros`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.zeros.html)
1078pub fn zeros<Args, Inp>(param: Args) -> Args::Out
1079where
1080    Args: ZerosAPI<Inp>,
1081{
1082    return ZerosAPI::zeros(param);
1083}
1084
1085pub fn zeros_f<Args, Inp>(param: Args) -> Result<Args::Out>
1086where
1087    Args: ZerosAPI<Inp>,
1088{
1089    return ZerosAPI::zeros_f(param);
1090}
1091
1092impl<T, D, B> ZerosAPI<(T, D)> for (Layout<D>, &B)
1093where
1094    T: Num,
1095    D: DimAPI,
1096    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1097{
1098    type Out = Tensor<T, B, IxD>;
1099
1100    fn zeros_f(self) -> Result<Self::Out> {
1101        let (layout, device) = self;
1102        let (_, idx_max) = layout.bounds_index()?;
1103        let storage = B::zeros_impl(device, idx_max)?;
1104        unsafe { Ok(Tensor::new_unchecked(storage, layout.into_dim()?)) }
1105    }
1106}
1107
1108impl<T, D, B> ZerosAPI<(T, D)> for (D, FlagOrder, &B)
1109where
1110    T: Num,
1111    D: DimAPI,
1112    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1113{
1114    type Out = Tensor<T, B, IxD>;
1115
1116    fn zeros_f(self) -> Result<Self::Out> {
1117        let (shape, order, device) = self;
1118        let layout = shape.new_contig(None, order);
1119        zeros_f((layout, device))
1120    }
1121}
1122
1123impl<T, D, B> ZerosAPI<(T, D)> for (D, &B)
1124where
1125    T: Num,
1126    D: DimAPI,
1127    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1128{
1129    type Out = Tensor<T, B, IxD>;
1130
1131    fn zeros_f(self) -> Result<Self::Out> {
1132        let (shape, device) = self;
1133        let default_order = device.default_order();
1134        let layout = shape.new_contig(None, default_order);
1135        zeros_f((layout, device))
1136    }
1137}
1138
1139impl<T, D> ZerosAPI<(T, D)> for (D, FlagOrder)
1140where
1141    T: Num + Clone,
1142    D: DimAPI,
1143{
1144    type Out = Tensor<T, DeviceCpu, IxD>;
1145
1146    fn zeros_f(self) -> Result<Self::Out> {
1147        let (shape, order) = self;
1148        zeros_f((shape, order, &DeviceCpu::default()))
1149    }
1150}
1151
1152#[duplicate_item(L; [D]; [Layout<D>])]
1153impl<T, D> ZerosAPI<(T, D)> for L
1154where
1155    T: Num + Clone,
1156    D: DimAPI,
1157{
1158    type Out = Tensor<T, DeviceCpu, IxD>;
1159
1160    fn zeros_f(self) -> Result<Self::Out> {
1161        zeros_f((self, &DeviceCpu::default()))
1162    }
1163}
1164
1165/* #endregion */
1166
1167/* #region zeros_like */
1168
1169pub trait ZerosLikeAPI<Inp> {
1170    type Out;
1171
1172    fn zeros_like_f(self) -> Result<Self::Out>;
1173
1174    fn zeros_like(self) -> Self::Out
1175    where
1176        Self: Sized,
1177    {
1178        Self::zeros_like_f(self).unwrap()
1179    }
1180}
1181
1182/// New tensor filled with zeros and having the same shape as an input
1183/// tensor.
1184///
1185/// # See also
1186///
1187/// - [Python array API standard: `zeros_like`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.zeros_like.html)
1188pub fn zeros_like<Args, Inp>(param: Args) -> Args::Out
1189where
1190    Args: ZerosLikeAPI<Inp>,
1191{
1192    return ZerosLikeAPI::zeros_like(param);
1193}
1194
1195pub fn zeros_like_f<Args, Inp>(param: Args) -> Result<Args::Out>
1196where
1197    Args: ZerosLikeAPI<Inp>,
1198{
1199    return ZerosLikeAPI::zeros_like_f(param);
1200}
1201
1202impl<R, T, B, D> ZerosLikeAPI<()> for (&TensorAny<R, T, B, D>, TensorIterOrder, &B)
1203where
1204    R: DataAPI<Data = B::Raw>,
1205    T: Num,
1206    D: DimAPI,
1207    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1208{
1209    type Out = Tensor<T, B, D>;
1210
1211    fn zeros_like_f(self) -> Result<Self::Out> {
1212        let (tensor, order, device) = self;
1213        let layout = layout_for_array_copy(tensor.layout(), order)?;
1214        let idx_max = layout.size();
1215        let storage = B::zeros_impl(device, idx_max)?;
1216        unsafe { Ok(Tensor::new_unchecked(storage, layout)) }
1217    }
1218}
1219
1220impl<R, T, B, D> ZerosLikeAPI<()> for (&TensorAny<R, T, B, D>, &B)
1221where
1222    R: DataAPI<Data = B::Raw>,
1223    T: Num,
1224    D: DimAPI,
1225    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1226{
1227    type Out = Tensor<T, B, D>;
1228
1229    fn zeros_like_f(self) -> Result<Self::Out> {
1230        let (tensor, device) = self;
1231        zeros_like_f((tensor, TensorIterOrder::default(), device))
1232    }
1233}
1234
1235impl<R, T, B, D> ZerosLikeAPI<()> for (&TensorAny<R, T, B, D>, TensorIterOrder)
1236where
1237    R: DataAPI<Data = B::Raw>,
1238    T: Num,
1239    D: DimAPI,
1240    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1241{
1242    type Out = Tensor<T, B, D>;
1243
1244    fn zeros_like_f(self) -> Result<Self::Out> {
1245        let (tensor, order) = self;
1246        let device = tensor.device();
1247        zeros_like_f((tensor, order, device))
1248    }
1249}
1250
1251impl<R, T, B, D> ZerosLikeAPI<()> for &TensorAny<R, T, B, D>
1252where
1253    R: DataAPI<Data = B::Raw>,
1254    T: Num,
1255    D: DimAPI,
1256    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1257{
1258    type Out = Tensor<T, B, D>;
1259
1260    fn zeros_like_f(self) -> Result<Self::Out> {
1261        zeros_like_f((self, TensorIterOrder::default(), self.device()))
1262    }
1263}
1264
1265impl<R, T, B, D> TensorAny<R, T, B, D>
1266where
1267    R: DataAPI<Data = B::Raw>,
1268    D: DimAPI,
1269    T: Num,
1270    B: DeviceAPI<T> + DeviceCreationNumAPI<T>,
1271{
1272    /// New tensor filled with zeros and having the same shape as an input
1273    /// tensor.
1274    ///
1275    /// # See also
1276    ///
1277    /// [`zeros_like`]
1278    pub fn zeros_like(&self) -> Tensor<T, B, D> {
1279        zeros_like((self, TensorIterOrder::default(), self.device()))
1280    }
1281
1282    pub fn zeros_like_f(&self) -> Result<Tensor<T, B, D>> {
1283        zeros_like_f((self, TensorIterOrder::default(), self.device()))
1284    }
1285}
1286
1287/* #endregion */
1288
1289/* #region tril */
1290
1291pub trait TrilAPI<Inp> {
1292    type Out;
1293
1294    fn tril_f(self) -> Result<Self::Out>;
1295
1296    fn tril(self) -> Self::Out
1297    where
1298        Self: Sized,
1299    {
1300        Self::tril_f(self).unwrap()
1301    }
1302}
1303
1304/// Returns the lower triangular part of a matrix (or a stack of matrices) x.
1305///
1306/// # See also
1307///
1308/// - [Python array API standard: `tril`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.tril.html)
1309pub fn tril<Args, Inp>(param: Args) -> Args::Out
1310where
1311    Args: TrilAPI<Inp>,
1312{
1313    return TrilAPI::tril(param);
1314}
1315
1316pub fn tril_f<Args, Inp>(param: Args) -> Result<Args::Out>
1317where
1318    Args: TrilAPI<Inp>,
1319{
1320    return TrilAPI::tril_f(param);
1321}
1322
1323impl<T, D, B> TrilAPI<()> for (TensorView<'_, T, B, D>, isize)
1324where
1325    T: Num + Clone,
1326    D: DimAPI,
1327    B: DeviceAPI<T>
1328        + DeviceCreationTriAPI<T>
1329        + DeviceCreationAnyAPI<T>
1330        + OpAssignArbitaryAPI<T, D, D>
1331        + OpAssignAPI<T, D>,
1332    B::Raw: Clone,
1333{
1334    type Out = Tensor<T, B, D>;
1335
1336    fn tril_f(self) -> Result<Self::Out> {
1337        let (x, k) = self;
1338        let default_order = x.device().default_order();
1339        let mut x = x.into_contig_f(default_order)?;
1340        let device = x.device().clone();
1341        let layout = x.layout().clone();
1342        device.tril_impl(x.raw_mut(), &layout, k)?;
1343        Ok(x)
1344    }
1345}
1346
1347impl<T, D, B> TrilAPI<()> for TensorView<'_, T, B, D>
1348where
1349    T: Num + Clone,
1350    D: DimAPI,
1351    B: DeviceAPI<T>
1352        + DeviceCreationTriAPI<T>
1353        + DeviceCreationAnyAPI<T>
1354        + OpAssignArbitaryAPI<T, D, D>
1355        + OpAssignAPI<T, D>,
1356    B::Raw: Clone,
1357{
1358    type Out = Tensor<T, B, D>;
1359
1360    fn tril_f(self) -> Result<Self::Out> {
1361        tril_f((self, 0))
1362    }
1363}
1364
1365impl<'a, T, D, B> TrilAPI<()> for (TensorMut<'a, T, B, D>, isize)
1366where
1367    T: Num + Clone,
1368    D: DimAPI,
1369    B: DeviceAPI<T>
1370        + DeviceCreationTriAPI<T>
1371        + DeviceCreationAnyAPI<T>
1372        + OpAssignArbitaryAPI<T, D, D>
1373        + OpAssignAPI<T, D>,
1374{
1375    type Out = TensorMut<'a, T, B, D>;
1376
1377    fn tril_f(self) -> Result<Self::Out> {
1378        let (mut x, k) = self;
1379        let device = x.device().clone();
1380        let layout = x.layout().clone();
1381        device.tril_impl(x.raw_mut(), &layout, k)?;
1382        Ok(x)
1383    }
1384}
1385
1386impl<'a, T, D, B> TrilAPI<()> for TensorMut<'a, T, B, D>
1387where
1388    T: Num + Clone,
1389    D: DimAPI,
1390    B: DeviceAPI<T>
1391        + DeviceCreationTriAPI<T>
1392        + DeviceCreationAnyAPI<T>
1393        + OpAssignArbitaryAPI<T, D, D>
1394        + OpAssignAPI<T, D>,
1395{
1396    type Out = TensorMut<'a, T, B, D>;
1397
1398    fn tril_f(self) -> Result<Self::Out> {
1399        tril_f((self, 0))
1400    }
1401}
1402
1403impl<T, D, B> TrilAPI<()> for (Tensor<T, B, D>, isize)
1404where
1405    T: Num + Clone,
1406    D: DimAPI,
1407    B: DeviceAPI<T> + DeviceCreationTriAPI<T> + DeviceCreationAnyAPI<T>,
1408{
1409    type Out = Tensor<T, B, D>;
1410
1411    fn tril_f(self) -> Result<Self::Out> {
1412        let (mut x, k) = self;
1413        let device = x.device().clone();
1414        let layout = x.layout().clone();
1415        device.tril_impl(x.raw_mut(), &layout, k)?;
1416        Ok(x)
1417    }
1418}
1419
1420impl<T, D, B> TrilAPI<()> for Tensor<T, B, D>
1421where
1422    T: Num + Clone,
1423    D: DimAPI,
1424    B: DeviceAPI<T> + DeviceCreationTriAPI<T> + DeviceCreationAnyAPI<T>,
1425{
1426    type Out = Tensor<T, B, D>;
1427
1428    fn tril_f(self) -> Result<Self::Out> {
1429        tril_f((self, 0))
1430    }
1431}
1432
1433impl<R, T, D, B> TrilAPI<()> for (&TensorAny<R, T, B, D>, isize)
1434where
1435    R: DataAPI<Data = B::Raw>,
1436    T: Num + Clone,
1437    D: DimAPI,
1438    B: DeviceAPI<T>
1439        + DeviceCreationTriAPI<T>
1440        + DeviceCreationAnyAPI<T>
1441        + OpAssignArbitaryAPI<T, D, D>
1442        + OpAssignAPI<T, D>,
1443    B::Raw: Clone,
1444{
1445    type Out = Tensor<T, B, D>;
1446
1447    fn tril_f(self) -> Result<Self::Out> {
1448        let (x, k) = self;
1449        tril_f((x.view(), k))
1450    }
1451}
1452
1453impl<R, T, D, B> TrilAPI<()> for &TensorAny<R, T, B, D>
1454where
1455    R: DataAPI<Data = B::Raw>,
1456    T: Num + Clone,
1457    D: DimAPI,
1458    B: DeviceAPI<T>
1459        + DeviceCreationTriAPI<T>
1460        + DeviceCreationAnyAPI<T>
1461        + OpAssignArbitaryAPI<T, D, D>
1462        + OpAssignAPI<T, D>,
1463    B::Raw: Clone,
1464{
1465    type Out = Tensor<T, B, D>;
1466
1467    fn tril_f(self) -> Result<Self::Out> {
1468        tril_f((self.view(), 0))
1469    }
1470}
1471
1472/* #endregion */
1473
1474/* #region triu */
1475
1476pub trait TriuAPI<Inp> {
1477    type Out;
1478
1479    fn triu_f(self) -> Result<Self::Out>;
1480
1481    fn triu(self) -> Self::Out
1482    where
1483        Self: Sized,
1484    {
1485        Self::triu_f(self).unwrap()
1486    }
1487}
1488
1489/// Returns the upper triangular part of a matrix (or a stack of matrices) x.
1490///
1491/// # See also
1492///
1493/// - [Python array API standard: `triu`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.triu.html)
1494pub fn triu<Args, Inp>(param: Args) -> Args::Out
1495where
1496    Args: TriuAPI<Inp>,
1497{
1498    return TriuAPI::triu(param);
1499}
1500
1501pub fn triu_f<Args, Inp>(param: Args) -> Result<Args::Out>
1502where
1503    Args: TriuAPI<Inp>,
1504{
1505    return TriuAPI::triu_f(param);
1506}
1507
1508impl<T, D, B> TriuAPI<()> for (TensorView<'_, T, B, D>, isize)
1509where
1510    T: Num + Clone,
1511    D: DimAPI,
1512    B: DeviceAPI<T>
1513        + DeviceCreationTriAPI<T>
1514        + DeviceCreationAnyAPI<T>
1515        + OpAssignArbitaryAPI<T, D, D>
1516        + OpAssignAPI<T, D>,
1517    B::Raw: Clone,
1518{
1519    type Out = Tensor<T, B, D>;
1520
1521    fn triu_f(self) -> Result<Self::Out> {
1522        let (x, k) = self;
1523        let default_order = x.device().default_order();
1524        let mut x = x.into_contig_f(default_order)?;
1525        let device = x.device().clone();
1526        let layout = x.layout().clone();
1527        device.triu_impl(x.raw_mut(), &layout, k)?;
1528        Ok(x)
1529    }
1530}
1531
1532impl<T, D, B> TriuAPI<()> for TensorView<'_, T, B, D>
1533where
1534    T: Num + Clone,
1535    D: DimAPI,
1536    B: DeviceAPI<T>
1537        + DeviceCreationTriAPI<T>
1538        + DeviceCreationAnyAPI<T>
1539        + OpAssignArbitaryAPI<T, D, D>
1540        + OpAssignAPI<T, D>,
1541    B::Raw: Clone,
1542{
1543    type Out = Tensor<T, B, D>;
1544
1545    fn triu_f(self) -> Result<Self::Out> {
1546        triu_f((self, 0))
1547    }
1548}
1549
1550impl<'a, T, D, B> TriuAPI<()> for (TensorMut<'a, T, B, D>, isize)
1551where
1552    T: Num + Clone,
1553    D: DimAPI,
1554    B: DeviceAPI<T>
1555        + DeviceCreationTriAPI<T>
1556        + DeviceCreationAnyAPI<T>
1557        + OpAssignArbitaryAPI<T, D, D>
1558        + OpAssignAPI<T, D>,
1559{
1560    type Out = TensorMut<'a, T, B, D>;
1561
1562    fn triu_f(self) -> Result<Self::Out> {
1563        let (mut x, k) = self;
1564        let device = x.device().clone();
1565        let layout = x.layout().clone();
1566        device.triu_impl(x.raw_mut(), &layout, k)?;
1567        Ok(x)
1568    }
1569}
1570
1571impl<'a, T, D, B> TriuAPI<()> for TensorMut<'a, T, B, D>
1572where
1573    T: Num + Clone,
1574    D: DimAPI,
1575    B: DeviceAPI<T>
1576        + DeviceCreationTriAPI<T>
1577        + DeviceCreationAnyAPI<T>
1578        + OpAssignArbitaryAPI<T, D, D>
1579        + OpAssignAPI<T, D>,
1580{
1581    type Out = TensorMut<'a, T, B, D>;
1582
1583    fn triu_f(self) -> Result<Self::Out> {
1584        triu_f((self, 0))
1585    }
1586}
1587
1588impl<T, D, B> TriuAPI<()> for (Tensor<T, B, D>, isize)
1589where
1590    T: Num + Clone,
1591    D: DimAPI,
1592    B: DeviceAPI<T> + DeviceCreationTriAPI<T> + DeviceCreationAnyAPI<T>,
1593{
1594    type Out = Tensor<T, B, D>;
1595
1596    fn triu_f(self) -> Result<Self::Out> {
1597        let (mut x, k) = self;
1598        let device = x.device().clone();
1599        let layout = x.layout().clone();
1600        device.triu_impl(x.raw_mut(), &layout, k)?;
1601        Ok(x)
1602    }
1603}
1604
1605impl<T, D, B> TriuAPI<()> for Tensor<T, B, D>
1606where
1607    T: Num + Clone,
1608    D: DimAPI,
1609    B: DeviceAPI<T> + DeviceCreationTriAPI<T> + DeviceCreationAnyAPI<T>,
1610{
1611    type Out = Tensor<T, B, D>;
1612
1613    fn triu_f(self) -> Result<Self::Out> {
1614        triu_f((self, 0))
1615    }
1616}
1617
1618impl<R, T, D, B> TriuAPI<()> for (&TensorAny<R, T, B, D>, isize)
1619where
1620    R: DataAPI<Data = B::Raw>,
1621    T: Num + Clone,
1622    D: DimAPI,
1623    B: DeviceAPI<T>
1624        + DeviceCreationTriAPI<T>
1625        + DeviceCreationAnyAPI<T>
1626        + OpAssignArbitaryAPI<T, D, D>
1627        + OpAssignAPI<T, D>,
1628    B::Raw: Clone,
1629{
1630    type Out = Tensor<T, B, D>;
1631
1632    fn triu_f(self) -> Result<Self::Out> {
1633        let (x, k) = self;
1634        triu_f((x.view(), k))
1635    }
1636}
1637
1638impl<R, T, D, B> TriuAPI<()> for &TensorAny<R, T, B, D>
1639where
1640    R: DataAPI<Data = B::Raw>,
1641    T: Num + Clone,
1642    D: DimAPI,
1643    B: DeviceAPI<T>
1644        + DeviceCreationTriAPI<T>
1645        + DeviceCreationAnyAPI<T>
1646        + OpAssignArbitaryAPI<T, D, D>
1647        + OpAssignAPI<T, D>,
1648    B::Raw: Clone,
1649{
1650    type Out = Tensor<T, B, D>;
1651
1652    fn triu_f(self) -> Result<Self::Out> {
1653        triu_f((self.view(), 0))
1654    }
1655}
1656
1657/* #endregion */
1658
// Smoke tests for the creation functions. They only print the results;
// nothing is asserted, so they verify that the calls compile and do not
// panic rather than checking values.
#[cfg(test)]
mod test {
    use super::*;
    use num::complex::Complex32;

    /// Exercises one or more argument forms of each creation function and
    /// prints every resulting tensor. `a` is shadowed repeatedly, and the
    /// `*_like` calls use the tensor created just before them.
    #[test]
    fn playground() {
        // arange: (start, stop, step), single bound, and bound + device.
        let a = arange((2.5, 3.2, 0.02));
        println!("{a:6.3?}");
        let a = arange(15.0);
        println!("{a:6.3?}");
        let a = arange((15.0, &DeviceCpu::default()));
        println!("{a:6.3?}");
        // empty / empty_like are wrapped in `unsafe` blocks here.
        let a: Tensor<f64, _> = unsafe { empty(([15, 18].f(), &DeviceCpuSerial::default())) };
        println!("{a:6.3?}");
        let a = unsafe { a.empty_like() };
        println!("{a:6.3?}");
        let a = unsafe { empty_like((&a, TensorIterOrder::C)) };
        println!("{a:6.3?}");
        // eye
        let a: Tensor<f64, _> = eye(3);
        println!("{a:6.3?}");
        // full / full_like: free-function and method forms.
        let a = full(([2, 2].f(), 3.16));
        println!("{a:6.3?}");
        let a = full_like((&a, 2.71));
        println!("{a:6.3?}");
        let a = a.full_like(2.71);
        println!("{a:6.3?}");
        // linspace with real and complex endpoints.
        let a = linspace((3.2, 4.7, 12));
        println!("{a:6.3?}");
        let a = linspace((Complex32::new(1.8, 7.5), Complex32::new(-8.9, 1.6), 12));
        println!("{a:6.3?}");
        // ones / ones_like
        let a: Tensor<f64> = ones(vec![2, 2]);
        println!("{a:6.3?}");
        let a = a.ones_like();
        println!("{a:6.3?}");
        // zeros: shape only, then shape + explicit device; zeros_like last.
        let a: Tensor<f64> = zeros([2, 2]);
        println!("{a:6.3?}");
        let a: Tensor<f64, _> = zeros(([2, 2], &DeviceCpuSerial::default()));
        println!("{a:6.3?}");
        let a = a.zeros_like();
        println!("{a:6.3?}");
    }

    /// Prints `tril` of a view and `triu` (with `k = 1`) of the owned
    /// tensor, both taken from `arange((1, 10))` reshaped to 3x3.
    /// Despite the name, this covers `triu` as well; no values are asserted.
    #[test]
    fn test_tril() {
        let a = arange((1, 10)).into_shape((3, 3));
        // Goes through a view, so `a` stays available for the next call.
        let b = a.view().tril();
        println!("{b:6.3?}");
        // Consumes `a` via the owned `(Tensor, isize)` form.
        let b = triu((a, 1));
        println!("{b:6.3?}");
    }
}