rstsr_core/tensor/
asarray.rs

1//! Implementation of function `asarray`.
2
3use crate::prelude_dev::*;
4use core::mem::ManuallyDrop;
5use num::complex::{Complex32, Complex64};
6
7pub trait AsArrayAPI<Inp> {
8    type Out;
9
10    fn asarray_f(self) -> Result<Self::Out>;
11
12    fn asarray(self) -> Self::Out
13    where
14        Self: Sized,
15    {
16        Self::asarray_f(self).unwrap()
17    }
18}
19
20/// Convert the input to an array.
21///
22/// This function takes kinds of input and converts them to an array. Please
23/// refer to trait implementations of [`AsArrayAPI`].
24///
25/// # See also
26///
27/// [Python array API: `asarray`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.asarray.html)
28pub fn asarray<Args, Inp>(param: Args) -> Args::Out
29where
30    Args: AsArrayAPI<Inp>,
31{
32    return AsArrayAPI::asarray(param);
33}
34
35pub fn asarray_f<Args, Inp>(param: Args) -> Result<Args::Out>
36where
37    Args: AsArrayAPI<Inp>,
38{
39    return AsArrayAPI::asarray_f(param);
40}
41
42/* #region tensor input */
43
44impl<R, T, B, D> AsArrayAPI<()> for (&TensorAny<R, T, B, D>, TensorIterOrder)
45where
46    R: DataAPI<Data = B::Raw>,
47    T: Clone,
48    D: DimAPI,
49    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
50{
51    type Out = Tensor<T, B, D>;
52
53    fn asarray_f(self) -> Result<Self::Out> {
54        let (input, order) = self;
55        let device = input.device();
56        let layout_a = input.layout();
57        let layout_c = layout_for_array_copy(layout_a, order)?;
58        let mut storage_c = unsafe { device.empty_impl(layout_c.size())? };
59        device.assign(storage_c.raw_mut(), &layout_c, input.raw(), layout_a)?;
60        let tensor = unsafe { Tensor::new_unchecked(storage_c, layout_c) };
61        return Ok(tensor);
62    }
63}
64
65impl<R, T, B, D> AsArrayAPI<()> for &TensorAny<R, T, B, D>
66where
67    R: DataAPI<Data = B::Raw>,
68    T: Clone,
69    D: DimAPI,
70    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
71{
72    type Out = Tensor<T, B, D>;
73
74    fn asarray_f(self) -> Result<Self::Out> {
75        asarray_f((self, TensorIterOrder::default()))
76    }
77}
78
79impl<T, B, D> AsArrayAPI<()> for (Tensor<T, B, D>, TensorIterOrder)
80where
81    T: Clone,
82    D: DimAPI,
83    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
84{
85    type Out = Tensor<T, B, D>;
86
87    fn asarray_f(self) -> Result<Self::Out> {
88        let (input, order) = self;
89        let storage_a = input.storage();
90        let layout_a = input.layout();
91        let device = storage_a.device();
92        let layout_c = layout_for_array_copy(layout_a, order)?;
93        if layout_c == *layout_a {
94            return Ok(input);
95        } else {
96            let mut storage_c = unsafe { device.empty_impl(layout_c.size())? };
97            device.assign(storage_c.raw_mut(), &layout_c, storage_a.raw(), layout_a)?;
98            let tensor = unsafe { Tensor::new_unchecked(storage_c, layout_c) };
99            return Ok(tensor);
100        }
101    }
102}
103
104impl<T, B, D> AsArrayAPI<()> for Tensor<T, B, D>
105where
106    T: Clone,
107    D: DimAPI,
108    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
109{
110    type Out = Tensor<T, B, D>;
111
112    fn asarray_f(self) -> Result<Self::Out> {
113        asarray_f((self, TensorIterOrder::default()))
114    }
115}
116
117/* #endregion */
118
119/* #region vec-like input */
120
121impl<T, B> AsArrayAPI<()> for (Vec<T>, &B)
122where
123    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
124{
125    type Out = Tensor<T, B, IxD>;
126
127    fn asarray_f(self) -> Result<Self::Out> {
128        let (input, device) = self;
129        let layout = vec![input.len()].c();
130        let storage = device.outof_cpu_vec(input)?;
131        let tensor = unsafe { Tensor::new_unchecked(storage, layout) };
132        return Ok(tensor);
133    }
134}
135
136impl<T, B, D> AsArrayAPI<D> for (Vec<T>, Layout<D>, &B)
137where
138    D: DimAPI,
139    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
140{
141    type Out = Tensor<T, B, IxD>;
142
143    fn asarray_f(self) -> Result<Self::Out> {
144        let (input, layout, device) = self;
145        rstsr_assert_eq!(
146            layout.bounds_index()?,
147            (0, layout.size()),
148            InvalidLayout,
149            "This constructor assumes compact memory layout."
150        )?;
151        rstsr_assert_eq!(
152            layout.size(),
153            input.len(),
154            InvalidLayout,
155            "This constructor assumes that the layout size is equal to the input size."
156        )?;
157        let storage = device.outof_cpu_vec(input)?;
158        let tensor = unsafe { Tensor::new_unchecked(storage, layout.into_dim()?) };
159        return Ok(tensor);
160    }
161}
162
163impl<T, B, D> AsArrayAPI<D> for (Vec<T>, D, &B)
164where
165    D: DimAPI,
166    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
167{
168    type Out = Tensor<T, B, IxD>;
169
170    fn asarray_f(self) -> Result<Self::Out> {
171        let (input, shape, device) = self;
172        let default_order = device.default_order();
173        let layout = match default_order {
174            RowMajor => shape.c(),
175            ColMajor => shape.f(),
176        };
177        asarray_f((input, layout, device))
178    }
179}
180
181impl<T> AsArrayAPI<()> for Vec<T>
182where
183    T: Clone,
184{
185    type Out = Tensor<T, DeviceCpu, IxD>;
186
187    fn asarray_f(self) -> Result<Self::Out> {
188        asarray_f((self, &DeviceCpu::default()))
189    }
190}
191
192#[duplicate_item(L; [D]; [Layout<D>])]
193impl<T, D> AsArrayAPI<D> for (Vec<T>, L)
194where
195    T: Clone,
196    D: DimAPI,
197{
198    type Out = Tensor<T, DeviceCpu, IxD>;
199
200    fn asarray_f(self) -> Result<Self::Out> {
201        let (input, layout) = self;
202        asarray_f((input, layout, &DeviceCpu::default()))
203    }
204}
205
206impl<T> From<Vec<T>> for Tensor<T, DeviceCpu, IxD>
207where
208    T: Clone,
209{
210    fn from(input: Vec<T>) -> Self {
211        asarray_f(input).unwrap()
212    }
213}
214
215/* #endregion */
216
217/* #region slice-like input */
218
219impl<'a, T, B, D> AsArrayAPI<D> for (&'a [T], Layout<D>, &B)
220where
221    T: Clone,
222    B: DeviceAPI<T, Raw = Vec<T>>,
223    D: DimAPI,
224{
225    type Out = TensorView<'a, T, B, IxD>;
226
227    fn asarray_f(self) -> Result<Self::Out> {
228        let (input, layout, device) = self;
229        rstsr_assert_eq!(
230            layout.bounds_index()?,
231            (0, layout.size()),
232            InvalidLayout,
233            "This constructor assumes compact memory layout."
234        )?;
235        rstsr_assert_eq!(
236            layout.size(),
237            input.len(),
238            InvalidLayout,
239            "This constructor assumes that the layout size is equal to the input size."
240        )?;
241        let ptr = input.as_ptr();
242        let len = input.len();
243        let raw = unsafe {
244            let ptr = ptr as *mut T;
245            Vec::from_raw_parts(ptr, len, len)
246        };
247        let device = device.clone();
248        let data = DataRef::from_manually_drop(ManuallyDrop::new(raw));
249        let storage = Storage::new(data, device);
250        let tensor = unsafe { TensorView::new_unchecked(storage, layout.into_dim()?) };
251        return Ok(tensor);
252    }
253}
254
255impl<'a, T, B, D> AsArrayAPI<D> for (&'a [T], D, &B)
256where
257    T: Clone,
258    B: DeviceAPI<T, Raw = Vec<T>>,
259    D: DimAPI,
260{
261    type Out = TensorView<'a, T, B, IxD>;
262
263    fn asarray_f(self) -> Result<Self::Out> {
264        let (input, shape, device) = self;
265        let default_order = device.default_order();
266        let layout = match default_order {
267            RowMajor => shape.c(),
268            ColMajor => shape.f(),
269        };
270        asarray_f((input, layout, device))
271    }
272}
273
274impl<'a, T, B> AsArrayAPI<()> for (&'a [T], &B)
275where
276    T: Clone,
277    B: DeviceAPI<T, Raw = Vec<T>>,
278{
279    type Out = TensorView<'a, T, B, IxD>;
280
281    fn asarray_f(self) -> Result<Self::Out> {
282        let (input, device) = self;
283        let layout = vec![input.len()].c();
284        let device = device.clone();
285
286        let ptr = input.as_ptr();
287        let len = input.len();
288        let raw = unsafe {
289            let ptr = ptr as *mut T;
290            Vec::from_raw_parts(ptr, len, len)
291        };
292        let data = DataRef::from_manually_drop(ManuallyDrop::new(raw));
293        let storage = Storage::new(data, device);
294        let tensor = unsafe { TensorView::new_unchecked(storage, layout) };
295        return Ok(tensor);
296    }
297}
298
299#[duplicate_item(L; [D]; [Layout<D>])]
300impl<'a, T, D> AsArrayAPI<D> for (&'a [T], L)
301where
302    T: Clone,
303    D: DimAPI,
304{
305    type Out = TensorView<'a, T, DeviceCpu, IxD>;
306
307    fn asarray_f(self) -> Result<Self::Out> {
308        let (input, layout) = self;
309        asarray_f((input, layout, &DeviceCpu::default()))
310    }
311}
312
313impl<'a, T> AsArrayAPI<()> for &'a [T]
314where
315    T: Clone,
316{
317    type Out = TensorView<'a, T, DeviceCpu, IxD>;
318
319    fn asarray_f(self) -> Result<Self::Out> {
320        asarray_f((self, &DeviceCpu::default()))
321    }
322}
323
324#[duplicate_item(L; [D]; [Layout<D>])]
325impl<'a, T, B, D> AsArrayAPI<D> for (&'a Vec<T>, L, &B)
326where
327    T: Clone,
328    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
329    D: DimAPI,
330{
331    type Out = TensorView<'a, T, B, IxD>;
332
333    fn asarray_f(self) -> Result<Self::Out> {
334        let (input, layout, device) = self;
335        asarray_f((input.as_slice(), layout, device))
336    }
337}
338
339impl<'a, T, B> AsArrayAPI<()> for (&'a Vec<T>, &B)
340where
341    T: Clone,
342    B: DeviceAPI<T, Raw = Vec<T>>,
343{
344    type Out = TensorView<'a, T, B, IxD>;
345
346    fn asarray_f(self) -> Result<Self::Out> {
347        let (input, device) = self;
348        asarray_f((input.as_slice(), device))
349    }
350}
351
352#[duplicate_item(L; [D]; [Layout<D>])]
353impl<'a, T, D> AsArrayAPI<D> for (&'a Vec<T>, L)
354where
355    T: Clone,
356    D: DimAPI,
357{
358    type Out = TensorView<'a, T, DeviceCpu, IxD>;
359
360    fn asarray_f(self) -> Result<Self::Out> {
361        let (input, layout) = self;
362        asarray_f((input.as_slice(), layout, &DeviceCpu::default()))
363    }
364}
365
366impl<'a, T> AsArrayAPI<()> for &'a Vec<T>
367where
368    T: Clone,
369{
370    type Out = TensorView<'a, T, DeviceCpu, IxD>;
371
372    fn asarray_f(self) -> Result<Self::Out> {
373        asarray_f((self.as_slice(), &DeviceCpu::default()))
374    }
375}
376
377impl<'a, T> From<&'a [T]> for TensorView<'a, T, DeviceCpu, IxD>
378where
379    T: Clone,
380{
381    fn from(input: &'a [T]) -> Self {
382        asarray(input)
383    }
384}
385
386impl<'a, T> From<&'a Vec<T>> for TensorView<'a, T, DeviceCpu, IxD>
387where
388    T: Clone,
389{
390    fn from(input: &'a Vec<T>) -> Self {
391        asarray(input)
392    }
393}
394
395/* #endregion */
396
397/* #region slice-like mutable input */
398
399impl<'a, T, B, D> AsArrayAPI<D> for (&'a mut [T], Layout<D>, &B)
400where
401    T: Clone,
402    B: DeviceAPI<T, Raw = Vec<T>>,
403    D: DimAPI,
404{
405    type Out = TensorMut<'a, T, B, IxD>;
406
407    fn asarray_f(self) -> Result<Self::Out> {
408        let (input, layout, device) = self;
409        rstsr_assert_eq!(
410            layout.bounds_index()?,
411            (0, layout.size()),
412            InvalidLayout,
413            "This constructor assumes compact memory layout."
414        )?;
415        rstsr_assert_eq!(
416            layout.size(),
417            input.len(),
418            InvalidLayout,
419            "This constructor assumes that the layout size is equal to the input size."
420        )?;
421        let ptr = input.as_ptr();
422        let len = input.len();
423        let raw = unsafe {
424            let ptr = ptr as *mut T;
425            Vec::from_raw_parts(ptr, len, len)
426        };
427        let device = device.clone();
428        let data = DataMut::from_manually_drop(ManuallyDrop::new(raw));
429        let storage = Storage::new(data, device);
430        let tensor = unsafe { TensorMut::new_unchecked(storage, layout.into_dim()?) };
431        return Ok(tensor);
432    }
433}
434
435impl<'a, T, B, D> AsArrayAPI<D> for (&'a mut [T], D, &B)
436where
437    T: Clone,
438    B: DeviceAPI<T, Raw = Vec<T>>,
439    D: DimAPI,
440{
441    type Out = TensorMut<'a, T, B, IxD>;
442
443    fn asarray_f(self) -> Result<Self::Out> {
444        let (input, shape, device) = self;
445        let default_order = device.default_order();
446        let layout = match default_order {
447            RowMajor => shape.c(),
448            ColMajor => shape.f(),
449        };
450        asarray_f((input, layout, device))
451    }
452}
453
454impl<'a, T, B> AsArrayAPI<()> for (&'a mut [T], &B)
455where
456    T: Clone,
457    B: DeviceAPI<T, Raw = Vec<T>>,
458{
459    type Out = TensorMut<'a, T, B, IxD>;
460
461    fn asarray_f(self) -> Result<Self::Out> {
462        let (input, device) = self;
463        let layout = [input.len()].c();
464        let device = device.clone();
465
466        let ptr = input.as_ptr();
467        let len = input.len();
468        let raw = unsafe {
469            let ptr = ptr as *mut T;
470            Vec::from_raw_parts(ptr, len, len)
471        };
472        let data = DataMut::from_manually_drop(ManuallyDrop::new(raw));
473        let storage = Storage::new(data, device);
474        let tensor = unsafe { TensorMut::new_unchecked(storage, layout.into_dim()?) };
475        return Ok(tensor);
476    }
477}
478
479#[duplicate_item(L; [D]; [Layout<D>])]
480impl<'a, T, D> AsArrayAPI<D> for (&'a mut [T], L)
481where
482    T: Clone,
483    D: DimAPI,
484{
485    type Out = TensorMut<'a, T, DeviceCpu, IxD>;
486
487    fn asarray_f(self) -> Result<Self::Out> {
488        let (input, layout) = self;
489        asarray_f((input, layout, &DeviceCpu::default()))
490    }
491}
492
493impl<'a, T> AsArrayAPI<()> for &'a mut [T]
494where
495    T: Clone,
496{
497    type Out = TensorMut<'a, T, DeviceCpu, IxD>;
498
499    fn asarray_f(self) -> Result<Self::Out> {
500        asarray_f((self, &DeviceCpu::default()))
501    }
502}
503
504#[duplicate_item(L; [D]; [Layout<D>])]
505impl<'a, T, B, D> AsArrayAPI<D> for (&'a mut Vec<T>, L, &B)
506where
507    T: Clone,
508    B: DeviceAPI<T, Raw = Vec<T>>,
509    D: DimAPI,
510{
511    type Out = TensorMut<'a, T, B, IxD>;
512
513    fn asarray_f(self) -> Result<Self::Out> {
514        let (input, layout, device) = self;
515        asarray_f((input.as_mut_slice(), layout, device))
516    }
517}
518
519impl<'a, T, B> AsArrayAPI<()> for (&'a mut Vec<T>, &B)
520where
521    T: Clone,
522    B: DeviceAPI<T, Raw = Vec<T>>,
523{
524    type Out = TensorMut<'a, T, B, IxD>;
525
526    fn asarray_f(self) -> Result<Self::Out> {
527        let (input, device) = self;
528        asarray_f((input.as_mut_slice(), device))
529    }
530}
531
532#[duplicate_item(L; [D]; [Layout<D>])]
533impl<'a, T, D> AsArrayAPI<D> for (&'a mut Vec<T>, L)
534where
535    T: Clone,
536    D: DimAPI,
537{
538    type Out = TensorMut<'a, T, DeviceCpu, IxD>;
539
540    fn asarray_f(self) -> Result<Self::Out> {
541        let (input, layout) = self;
542        asarray_f((input.as_mut_slice(), layout, &DeviceCpu::default()))
543    }
544}
545
546impl<'a, T> AsArrayAPI<()> for &'a mut Vec<T>
547where
548    T: Clone,
549{
550    type Out = TensorMut<'a, T, DeviceCpu, IxD>;
551
552    fn asarray_f(self) -> Result<Self::Out> {
553        asarray_f((self.as_mut_slice(), &DeviceCpu::default()))
554    }
555}
556
557impl<'a, T> From<&'a mut [T]> for TensorMut<'a, T, DeviceCpu, IxD>
558where
559    T: Clone,
560{
561    fn from(input: &'a mut [T]) -> Self {
562        asarray(input)
563    }
564}
565
566impl<'a, T> From<&'a mut Vec<T>> for TensorMut<'a, T, DeviceCpu, IxD>
567where
568    T: Clone,
569{
570    fn from(input: &'a mut Vec<T>) -> Self {
571        asarray(input)
572    }
573}
574
575/* #endregion */
576
577/* #region scalar input */
578
579macro_rules! impl_asarray_scalar {
580    ($($t:ty),*) => {
581        $(
582            impl<B> AsArrayAPI<()> for ($t, &B)
583            where
584                B: DeviceAPI<$t> + DeviceCreationAnyAPI<$t>,
585            {
586                type Out = Tensor<$t, B, IxD>;
587
588                fn asarray_f(self) -> Result<Self::Out> {
589                    let (input, device) = self;
590                    let layout = Layout::new(vec![], vec![], 0)?;
591                    let storage = device.outof_cpu_vec(vec![input])?;
592                    let tensor = unsafe { Tensor::new_unchecked(storage, layout) };
593                    return Ok(tensor);
594                }
595            }
596
597            impl AsArrayAPI<()> for $t {
598                type Out = Tensor<$t, DeviceCpu, IxD>;
599
600                fn asarray_f(self) -> Result<Self::Out> {
601                    asarray_f((self, &DeviceCpu::default()))
602                }
603            }
604        )*
605    };
606}
607
608impl_asarray_scalar!(i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, Complex32, Complex64);
609
610/* #endregion */
611
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_asarray() {
        // Owned `Vec` -> owned tensor.
        let input = vec![1, 2, 3];
        let tensor = asarray_f(input).unwrap();
        println!("{tensor:?}");

        // Borrowed slice (of a stack array) -> view.
        let input = [1, 2, 3];
        let tensor = asarray_f(input.as_ref()).unwrap();
        println!("{tensor:?}");

        // Borrowed `Vec` -> view that must share the vector's memory (no copy).
        let input = vec![1, 2, 3];
        let tensor = asarray_f(&input).unwrap();
        println!("{:?}", tensor.raw().as_ptr());
        assert_eq!(tensor.raw().as_ptr(), input.as_ptr());
        println!("{tensor:?}");

        // Borrowed tensor + order -> always a freshly allocated copy.
        let tensor = asarray_f((&tensor, TensorIterOrder::K)).unwrap();
        assert_ne!(tensor.raw().as_ptr(), input.as_ptr());
        println!("{tensor:?}");

        // Owned tensor + order.
        let tensor = asarray_f((tensor, TensorIterOrder::K)).unwrap();
        println!("{tensor:?}");
    }

    #[test]
    fn test_asarray_scalar() {
        // A scalar becomes a 0-D tensor holding exactly one element.
        let tensor = asarray_f(1).unwrap();
        assert_eq!(tensor.raw().len(), 1);
        println!("{tensor:?}");
        let tensor = asarray_f((Complex64::new(0., 1.), &DeviceCpuSerial::default())).unwrap();
        println!("{tensor:?}");
    }
}