// hpt/to_tensor.rs

use crate::backend::Cpu;
use crate::tensor::{DiffTensor, Tensor};
use crate::tensor_base::_Tensor;
use half::bf16;
use half::f16;
use hpt_allocator::traits::Allocator;
use hpt_allocator::traits::AllocatorOutputRetrive;
use hpt_allocator::Backend;
use hpt_common::shape::shape::Shape;
use hpt_common::strides::strides_utils::shape_to_strides;
use hpt_common::utils::pointer::Pointer;
use hpt_traits::ops::creation::TensorCreator;
use hpt_traits::tensor::TensorLike;
use num::complex::{Complex32, Complex64};
use std::alloc::Layout;
use std::cell::RefCell;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::rc::Rc;
use std::sync::Arc;

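/// Converts a bare scalar of each listed element type into a one-element tensor of
/// shape `[1]` holding that value, e.g. (illustratively, assuming default type
/// parameters) `let t: Tensor<f32> = 3.0f32.into();`.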
macro_rules! from_scalar {
    ($($t:ident),*) => {
        $(
            impl<const DEVICE: usize, A> Into<_Tensor<$t, Cpu, DEVICE, A>> for $t where A: Allocator, A::Output: AllocatorOutputRetrive {
                fn into(self) -> _Tensor<$t, Cpu, DEVICE, A> {
                    let mut ret = _Tensor::<$t, Cpu, DEVICE, A>::empty(vec![1]).unwrap();
                    ret.as_raw_mut()[0] = self;
                    return ret;
                }
            }
            impl<const DEVICE: usize, A> Into<Tensor<$t, Cpu, DEVICE, A>> for $t where A: Allocator, A::Output: AllocatorOutputRetrive {
                fn into(self) -> Tensor<$t, Cpu, DEVICE, A> {
                    Tensor {
                        inner: Arc::new(self.into()),
                    }
                }
            }
        )*
    };
}

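/// Generates the `From` conversions into tensors. The arms cover:
/// - `num`: maps an element type to its `Dtype` tag via `map_type_num!`;
/// - `vec`: `Vec<T>` into a 1-D tensor, reusing the Vec's buffer when it is 8-byte
///   aligned and copying into a freshly allocated buffer otherwise;
/// - `ndarray` / `ndarray_ref`: owned / borrowed nested arrays `[[T; M]; N]` (and deeper)
///   into N-D tensors;
/// - `ndarray_source_target`: nested `f32`/`f64` arrays converted element-wise into
///   `Complex32`/`Complex64` tensors.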
macro_rules! impl_type_num {
    (num, $($t:ident),*) => {
        $(
            impl TypeNum for $t {
                fn type_num() -> Dtype {
                    return map_type_num!($t);
                }
            }
        )*
    };

    (vec, $($t:ident),*) => {
        $(
            impl<const DEVICE: usize, A> From<Vec<$t>> for _Tensor<$t, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
                fn from(mut data: Vec<$t>) -> Self {
                    let mut ptr = data.as_mut_ptr();
                    let length = data.len();
                    let res_shape = Shape::from(vec![length as i64]);
                    let layout;
                    let mut allocator = A::new();
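                    // If the Vec's buffer is already 8-byte aligned, take ownership of it
                    // (ManuallyDrop keeps the Vec from freeing it) and register the pointer
                    // with the allocator; otherwise allocate a fresh aligned buffer and copy
                    // the elements over.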
                    if (ptr as usize) % 8 == 0 {
                        let _ = ManuallyDrop::new(data);
                        layout = Layout::from_size_align(length * std::mem::size_of::<$t>(), 8).unwrap();
                        allocator.insert_ptr(ptr as *mut u8, DEVICE);
                    } else {
                        layout = Layout::from_size_align(length * std::mem::size_of::<$t>(), 8).unwrap();
                        let allocate_res = allocator.allocate(layout, DEVICE).unwrap();
                        ptr = allocate_res.get_ptr() as *mut $t;
                        unsafe {
                            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, length);
                        }
                    }
                    let ly = hpt_common::layout::layout::Layout::new(res_shape, vec![1]);
                    return _Tensor {
                        #[cfg(not(feature = "bound_check"))]
                        data: Pointer::new(ptr),
                        #[cfg(feature = "bound_check")]
                        data: Pointer::new(ptr, length as i64),
                        parent: None,
                        layout: ly,
                        mem_layout: Arc::new(layout),
                        backend: Backend::<Cpu>::new(ptr as u64, DEVICE, true),
                        phantom: PhantomData,
                    };
                }
            }
            impl<const DEVICE: usize> From<Vec<$t>> for Tensor<$t, Cpu, DEVICE> {
                fn from(data: Vec<$t>) -> Self {
                    Tensor {
                        inner: Arc::new(data.into()),
                    }
                }
            }
        )*
    };
    (ndarray, $($generic:ident),*; $($vars:ident),*; $ct:ident, $($t:ident),*) => {
            impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $ct)> for _Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
                fn from(data: repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                    let mut vec: Vec<$ct> = Vec::with_capacity(repeate_generic!(operations, *, $($generic), *));
                    let shape = Shape::from(vec![$($generic as i64), *]);

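                    // Flatten the nested const-generic array into a contiguous Vec in row-major order.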
                    repeate_generic!(iterate, data; vec; $($vars), *).for_each(|element| vec.push(element));
                    let mut ptr = vec.as_mut_ptr();
                    let length = repeate_generic!(mul, $($generic), *);
                    let layout;
                    let mut allocator = A::new();
                    if (ptr as usize) % 8 == 0 {
                        let _ = ManuallyDrop::new(vec);
                        layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                        allocator.insert_ptr(ptr as *mut u8, DEVICE);
                    } else {
                        layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                        let allocate_res = allocator.allocate(layout, DEVICE).unwrap();
                        ptr = allocate_res.get_ptr() as *mut $ct;
                        unsafe {
                            std::ptr::copy_nonoverlapping(vec.as_ptr(), ptr, vec.len());
                        }
                    }
                    let strides = shape_to_strides(&shape);
                    let ly = hpt_common::layout::layout::Layout::new(shape, strides);
                    return _Tensor {
                        #[cfg(not(feature = "bound_check"))]
                        data: Pointer::new(ptr),
                        #[cfg(feature = "bound_check")]
                        data: Pointer::new(ptr, length as i64),
                        parent: None,
                        layout: ly,
                        mem_layout: Arc::new(layout),
                        backend: Backend::<Cpu>::new(ptr as u64, DEVICE, true),
                        phantom: PhantomData,
                    };
                }
            }
            impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $ct)> for Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
                fn from(data: repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                    Tensor {
                        inner: Arc::new(data.into()),
                    }
                }
            }
            impl_type_num!(ndarray, $($generic), *; $($vars), *; $($t),*);
    };
    (ndarray, $($generic:ident),*; $($vars:ident),*; $ct:ident) => {
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $ct)> for _Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
            fn from(data: repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                let mut vec: Vec<$ct> = Vec::with_capacity(repeate_generic!(operations, *, $($generic), *));
                let shape = Shape::from(vec![$($generic as i64), *]);

                repeate_generic!(iterate, data; vec; $($vars), *).for_each(|element| vec.push(element));
                let mut ptr = vec.as_mut_ptr();
                let length = repeate_generic!(mul, $($generic), *);
                let layout;
                let mut allocator = A::new();
                if (ptr as usize) % 8 == 0 {
                    let _ = ManuallyDrop::new(vec);
                    layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                    allocator.insert_ptr(ptr as *mut u8, DEVICE);
                } else {
                    layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                    let allocate_res = allocator.allocate(layout, DEVICE).unwrap();
                    ptr = allocate_res.get_ptr() as *mut $ct;
                    unsafe {
                        std::ptr::copy_nonoverlapping(vec.as_ptr(), ptr, vec.len());
                    }
                }
                let strides = shape_to_strides(&shape);

                let ly = hpt_common::layout::layout::Layout::new(shape, strides);
                return _Tensor {
                    #[cfg(not(feature = "bound_check"))]
                    data: Pointer::new(ptr),
                    #[cfg(feature = "bound_check")]
                    data: Pointer::new(ptr, length as i64),
                    parent: None,
                    layout: ly,
                    mem_layout: Arc::new(layout),
                    backend: Backend::<Cpu>::new(ptr as u64, DEVICE, true),
                    phantom: PhantomData,
                };
            }
        }
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $ct)> for Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
            fn from(data: repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                Tensor {
                    inner: Arc::new(data.into()),
                }
            }
        }
    };

    (
        ndarray_source_target,
        $source:ident,
        $($generic:ident),*;
        $($vars:ident),*;
        $ct:ident,
        $($t:ident),*
    ) => {
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $source)> for Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
            fn from(data: repeate_generic!(nested_array_type, $($generic), *; $source)) -> Self {
                Tensor {
                    inner: Arc::new(data.into()),
                }
            }
        }
        impl_type_num!(ndarray_source_target, $source, $($generic), *; $($vars), *; $($t),*);
    };
    (ndarray_source_target, $source:ident, $($generic:ident),*; $($vars:ident),*; $ct:ident) => {
    impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $source)> for _Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
        fn from(data: repeate_generic!(nested_array_type, $($generic), *; $source)) -> Self {
            let mut vec: Vec<$ct> = Vec::with_capacity(repeate_generic!(operations, *, $($generic), *));
            let shape = Shape::from(vec![$($generic as i64), *]);

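            // Flatten the nested source array, converting each element into the target
            // type (e.g. f32 -> Complex32) as it is pushed.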
            repeate_generic!(iterate, data; vec; $($vars), *).for_each(|element| vec.push(element.into()));
            let mut ptr = vec.as_mut_ptr();
            let length = repeate_generic!(mul, $($generic), *);
            let layout;
            let mut allocator = A::new();
            if (ptr as usize) % 8 == 0 {
                let _ = ManuallyDrop::new(vec);
                layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                allocator.insert_ptr(ptr as *mut u8, DEVICE);
            } else {
                layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                let allocate_res = allocator.allocate(layout, DEVICE).unwrap();
                ptr = allocate_res.get_ptr() as *mut $ct;
                unsafe {
                    std::ptr::copy_nonoverlapping(vec.as_ptr(), ptr, vec.len());
                }
            }
            let strides = shape_to_strides(&shape);

            let ly = hpt_common::layout::layout::Layout::new(shape, strides);
            return _Tensor {
                #[cfg(not(feature = "bound_check"))]
                data: Pointer::new(ptr),
                #[cfg(feature = "bound_check")]
                data: Pointer::new(ptr, length as i64),
                parent: None,
                layout: ly,
                mem_layout: Arc::new(layout),
                backend: Backend::<Cpu>::new(ptr as u64, DEVICE, true),
                phantom: PhantomData,
            };
        }
    }
    impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $source)> for Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
        fn from(data: repeate_generic!(nested_array_type, $($generic), *; $source)) -> Self {
            Tensor {
                inner: Arc::new(data.into()),
            }
        }
    }
    };

    (ndarray_ref, $($generic:ident),*; $($vars:ident),*; $ct:ident, $($t:ident),*) => {
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<&repeate_generic!(nested_array_type, $($generic), *; $ct)> for _Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
            fn from(data: &repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                let mut vec: Vec<$ct> = Vec::with_capacity(repeate_generic!(operations, *, $($generic), *));
                let shape = Shape::from(vec![$($generic as i64), *]);

                repeate_generic!(iterate, data; vec; $($vars), *).for_each(|element| vec.push(*element));
                let mut ptr = vec.as_mut_ptr();
                let length = repeate_generic!(mul, $($generic), *);
                let layout;
                let mut allocator = A::new();
                if (ptr as usize) % 8 == 0 {
                    let _ = ManuallyDrop::new(vec);
                    layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                    allocator.insert_ptr(ptr as *mut u8, DEVICE);
                } else {
                    layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                    let allocate_res = allocator.allocate(layout, DEVICE).unwrap();
                    ptr = allocate_res.get_ptr() as *mut $ct;
                    unsafe {
                        std::ptr::copy_nonoverlapping(vec.as_ptr(), ptr, vec.len());
                    }
                }
                let strides = shape_to_strides(&shape);

                let ly = hpt_common::layout::layout::Layout::new(shape, strides);
                return _Tensor {
                    #[cfg(not(feature = "bound_check"))]
                    data: Pointer::new(ptr),
                    #[cfg(feature = "bound_check")]
                    data: Pointer::new(ptr, length as i64),
                    parent: None,
                    layout: ly,
                    mem_layout: Arc::new(layout),
                    backend: Backend::<Cpu>::new(ptr as u64, DEVICE, true),
                    phantom: PhantomData,
                };
            }
        }
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<&repeate_generic!(nested_array_type, $($generic), *; $ct)> for Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
            fn from(data: &repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                Tensor {
                    inner: Arc::new(data.into()),
                }
            }
        }
        impl_type_num!(ndarray_ref, $($generic), *; $($vars), *; $($t),*);
    };
    (ndarray_ref, $($generic:ident),*; $($vars:ident),*; $ct:ident) => {
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<&repeate_generic!(nested_array_type, $($generic), *; $ct)> for _Tensor<$ct, Cpu, DEVICE, A> where A: Allocator, A::Output: AllocatorOutputRetrive {
            fn from(data: &repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                let mut vec: Vec<$ct> = Vec::with_capacity(repeate_generic!(operations, *, $($generic), *));
                let shape = Shape::from(vec![$($generic as i64), *]);

                repeate_generic!(iterate, data; vec; $($vars), *).for_each(|element| vec.push(*element));
                let mut ptr = vec.as_mut_ptr();
                let length = repeate_generic!(mul, $($generic), *);
                let layout;
                let mut allocator = A::new();
                if (ptr as usize) % 8 == 0 {
                    let _ = ManuallyDrop::new(vec);
                    layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                    allocator.insert_ptr(ptr as *mut u8, DEVICE);
                } else {
                    layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
                    let allocate_res = allocator.allocate(layout, DEVICE).unwrap();
                    ptr = allocate_res.get_ptr() as *mut $ct;
                    unsafe {
                        std::ptr::copy_nonoverlapping(vec.as_ptr(), ptr, vec.len());
                    }
                }
                let strides = shape_to_strides(&shape);

                let ly = hpt_common::layout::layout::Layout::new(shape, strides);
                return _Tensor {
                    #[cfg(not(feature = "bound_check"))]
                    data: Pointer::new(ptr),
                    #[cfg(feature = "bound_check")]
                    data: Pointer::new(ptr, length as i64),
                    parent: None,
                    layout: ly,
                    mem_layout: Arc::new(layout),
                    backend: Backend::<Cpu>::new(ptr as u64, DEVICE, true),
                    phantom: PhantomData,
                };
            }
        }
        impl<$(const $generic: usize), *, const DEVICE: usize> From<&repeate_generic!(nested_array_type, $($generic), *; $ct)> for Tensor<$ct, Cpu, DEVICE> {
            fn from(data: &repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                Tensor {
                    inner: Arc::new(data.into()),
                }
            }
        }
    };
}

/// Helper macro used by the `From` impls above: it expands nested const-generic array
/// types, builds the chained iterators that flatten them in row-major order, and
/// computes the compile-time products of their dimensions.
macro_rules! repeate_generic {
    (const, $($t:ident),*) => {
        impl<$(const $t: usize), *>
    };
    (nested_array, $n:expr, $($t:expr),*; $data_type:ident) => {
        [repeate_generic!(nested_array, $($t), *; $data_type);$n];
    };
    (nested_array, $t:expr; $data_type:ident) => {
        [$data_type; $t]
    };
    (nested_array_type, $n:expr, $($t:expr),*; $data_type:ident) => {
        [repeate_generic!(nested_array_type, $($t), *; $data_type);$n]
    };
    (nested_array_type, $t:expr; $data_type:ident) => {
        [$data_type; $t]
    };
    (operations, $op:tt, $n:expr, $($t:expr),*) => {
        $n $op repeate_generic!(operations, $op, $($t), *)
    };
    (operations, $op:tt, $n:expr) => {
        $n
    };
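    // The `iterate` arms chain one `flat_map` per dimension, yielding the elements in
    // row-major order; the innermost arm is a plain iterator over the last axis.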
    (iterate, $data:ident; $vec:ident; $n:ident, $($t:ident),*) => {
        $data.into_iter().flat_map(|$n| repeate_generic!(iterate, $vec; $n;; $($t), *))
    };
    (iterate, $data:ident; $vec:ident; $n:ident) => {
        $data.into_iter().flat_map(|$n| repeate_generic!(iterate, $vec; $n;;))
    };
    (iterate, $vec:ident; $n:ident; ; $n2:ident, $($t:ident),*) => {
        $n.into_iter().flat_map(|$n2| repeate_generic!(iterate, $vec; $n2;; $($t), *))
    };
    (iterate, $vec:ident; $n:ident; ; $n2:ident) => {
        $n.into_iter().flat_map(|$n2| repeate_generic!(iterate, $vec; $n2;;))
    };
    (iterate, $vec:ident; $n:ident; ;) => {
        $n.into_iter()
    };
    (iterate, $data:ident; $vec:ident;) => {
        $data.into_iter()
    };
    (mul, $n:expr, $($t:expr),*) => {
        $n * repeate_generic!(mul, $($t), *)
    };
    (mul, $n:expr) => {
        $n
    };
}

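// Concrete conversions: scalars, `Vec<T>`, and nested arrays (owned and borrowed) of up
// to eight dimensions for every supported element type, plus f32/f64 nested arrays that
// convert element-wise into Complex32/Complex64 tensors.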
from_scalar!(bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, bf16, f32, f64, Complex32, Complex64);
impl_type_num!(
    vec, bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64
); // prettier-ignore
impl_type_num!(ndarray, N; ; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_ref, N; ; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray, N, M; i; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_ref, N, M; i; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray, N, M, O; i, j; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_ref, N, M, O; i, j; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray, N, M, O, P; i, j, k; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_ref, N, M, O, P; i, j, k; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray, N, M, O, P, Q; i, j, k, l; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_ref, N, M, O, P, Q; i, j, k, l; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray, N, M, O, P, Q, R; i, j, k, l, m; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_ref, N, M, O, P, Q, R; i, j, k, l, m; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray, N, M, O, P, Q, R, S; i, j, k, l, m, n; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_ref, N, M, O, P, Q, R, S; i, j, k, l, m, n; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray, N, M, O, P, Q, R, S, T; i, j, k, l, m, n, o; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_ref, N, M, O, P, Q, R, S, T; i, j, k, l, m, n, o; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
impl_type_num!(ndarray_source_target, f32, N; ; Complex32);
impl_type_num!(ndarray_source_target, f64, N; ; Complex64);
impl_type_num!(ndarray_source_target, f32, N, M; i; Complex32);
impl_type_num!(ndarray_source_target, f64, N, M; i; Complex64);
impl_type_num!(ndarray_source_target, f32, N, M, O; i, j; Complex32);
impl_type_num!(ndarray_source_target, f64, N, M, O; i, j; Complex64);
impl_type_num!(ndarray_source_target, f32, N, M, O, P; i, j, k; Complex32);
impl_type_num!(ndarray_source_target, f64, N, M, O, P; i, j, k; Complex64);
impl_type_num!(ndarray_source_target, f32, N, M, O, P, Q; i, j, k, l; Complex32);
impl_type_num!(ndarray_source_target, f64, N, M, O, P, Q; i, j, k, l; Complex64);
impl_type_num!(ndarray_source_target, f32, N, M, O, P, Q, R; i, j, k, l, m; Complex32);
impl_type_num!(ndarray_source_target, f64, N, M, O, P, Q, R; i, j, k, l, m; Complex64);
impl_type_num!(ndarray_source_target, f32, N, M, O, P, Q, R, S; i, j, k, l, m, n; Complex32);
impl_type_num!(ndarray_source_target, f64, N, M, O, P, Q, R, S; i, j, k, l, m, n; Complex64);
impl_type_num!(ndarray_source_target, f32, N, M, O, P, Q, R, S, T; i, j, k, l, m, n, o; Complex32);
impl_type_num!(ndarray_source_target, f64, N, M, O, P, Q, R, S, T; i, j, k, l, m, n, o; Complex64);

impl<T, const DEVICE: usize, Al> Tensor<T, Cpu, DEVICE, Al>
where
    Al: Allocator,
{
    /// Creates a new tensor from the provided data.
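    ///
    /// A minimal usage sketch (illustrative; assumes the default `DEVICE` and allocator
    /// type parameters):
    ///
    /// ```ignore
    /// // From a nested array: a 2x2 f32 tensor.
    /// let a = Tensor::<f32>::new([[1.0f32, 2.0], [3.0, 4.0]]);
    /// // From a Vec: a 1-D i64 tensor of length 3.
    /// let b = Tensor::<i64>::new(vec![1i64, 2, 3]);
    /// ```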
    pub fn new<A>(data: A) -> Self
    where
        A: Into<Tensor<T, Cpu, DEVICE, Al>>,
    {
        data.into()
    }
}

impl<T, const DEVICE: usize, Al> DiffTensor<T, Cpu, DEVICE, Al>
where
    Al: Allocator,
{
    /// Creates a new differentiable tensor from the provided data.
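    ///
    /// A minimal sketch (illustrative; accepts the same inputs as `Tensor::new`, with the
    /// gradient slot starting out empty):
    ///
    /// ```ignore
    /// let x = DiffTensor::<f32>::new([[1.0f32, 2.0], [3.0, 4.0]]);
    /// ```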
    pub fn new<A>(data: A) -> Self
    where
        A: Into<Tensor<T, Cpu, DEVICE, Al>>,
    {
        let ret = data.into();
        DiffTensor {
            inner: ret,
            grad: Rc::new(RefCell::new(None)),
            out_degree: Rc::new(RefCell::new(0)),
            backward: Rc::new(RefCell::new(move |_| Ok(true))),
        }
    }
}