1use crate::backend::Cpu;
2use crate::tensor::{DiffTensor, Tensor};
3use crate::tensor_base::_Tensor;
4use half::bf16;
5use half::f16;
6use hpt_allocator::traits::Allocator;
7use hpt_allocator::traits::AllocatorOutputRetrive;
8use hpt_allocator::Backend;
9use hpt_common::shape::shape::Shape;
10use hpt_common::strides::strides_utils::shape_to_strides;
11use hpt_common::utils::pointer::Pointer;
12use hpt_traits::ops::creation::TensorCreator;
13use hpt_traits::tensor::TensorLike;
14use num::complex::{Complex32, Complex64};
15use std::alloc::Layout;
16use std::cell::RefCell;
17use std::marker::PhantomData;
18use std::mem::ManuallyDrop;
19use std::rc::Rc;
20use std::sync::Arc;
21
// Implements scalar -> tensor conversions for every listed element type.
//
// For each `$t` this generates `From<$t>` for both `_Tensor<$t, Cpu, DEVICE, A>`
// and `Tensor<$t, Cpu, DEVICE, A>`. `From` is implemented rather than `Into`
// because the standard library's blanket impl then provides `Into` for free,
// so `scalar.into()` keeps working and `Tensor::from(scalar)` becomes
// available as well.
macro_rules! from_scalar {
    ($($t:ident),*) => {
        $(
            impl<const DEVICE: usize, A> From<$t> for _Tensor<$t, Cpu, DEVICE, A>
            where
                A: Allocator,
                A::Output: AllocatorOutputRetrive,
            {
                // Allocates a tensor of shape `[1]` and stores the scalar in its
                // only slot. Panics if the allocation reports an error.
                fn from(value: $t) -> Self {
                    let mut ret = _Tensor::<$t, Cpu, DEVICE, A>::empty(vec![1]).unwrap();
                    ret.as_raw_mut()[0] = value;
                    ret
                }
            }

            impl<const DEVICE: usize, A> From<$t> for Tensor<$t, Cpu, DEVICE, A>
            where
                A: Allocator,
                A::Output: AllocatorOutputRetrive,
            {
                // Builds the inner `_Tensor` from the scalar and wraps it in an `Arc`.
                fn from(value: $t) -> Self {
                    Tensor {
                        inner: Arc::new(value.into()),
                    }
                }
            }
        )*
    };
}
42
// Generates the conversions that turn plain Rust data into CPU tensors.
//
// Public arms (selected by the first token):
//   * `num`                   - `TypeNum` impls.
//   * `vec`                   - `From<Vec<T>>`.
//   * `ndarray`               - `From<[[T; M]; N]>` (owned nested arrays).
//   * `ndarray_ref`           - `From<&[[T; M]; N]>` (borrowed nested arrays).
//   * `ndarray_source_target` - `From<[[S; M]; N]>` where every element is
//                               converted with `S: Into<T>`.
//
// The array arms take `dims; closure-vars; types`, where `closure-vars` has one
// identifier fewer than `dims` (they name the intermediate closure arguments
// used while flattening the nested array).
macro_rules! impl_type_num {
    // Implements `TypeNum` for every listed type via `map_type_num!`.
    // NOTE(review): `TypeNum`, `Dtype` and `map_type_num!` are not imported in
    // the visible part of this file and no `num` invocation is visible here;
    // confirm this arm is still needed.
    (num, $($t:ident),*) => {
        $(
            impl TypeNum for $t {
                fn type_num() -> Dtype {
                    return map_type_num!($t);
                }
            }
        )*
    };

    // Internal: turns an owned, fully populated `Vec<$ct>` named `$buf` plus a
    // ready-made tensor layout `$ly` into a `_Tensor`.
    //
    // If the vector's buffer address is a multiple of 8 the buffer is adopted
    // as-is (the vector is never dropped and the pointer is handed to the
    // allocator via `insert_ptr`); otherwise a fresh buffer is requested from
    // the allocator and the elements are copied into it.
    //
    // NOTE(review): in the adopt path the recorded memory layout uses
    // `len * size_of::<T>()` and alignment 8, whereas the vector was allocated
    // by `Vec` with its own capacity and `align_of::<T>()`. Confirm the
    // allocator's release path tolerates that difference.
    (@from_buffer $buf:ident, $ly:ident, $ct:ident, $alloc:ident, $device:ident) => {{
        let length = $buf.len();
        // `as_mut_ptr` (not `as_ptr`): the tensor may later write through this pointer.
        let mut ptr = $buf.as_mut_ptr();
        let layout = Layout::from_size_align(length * std::mem::size_of::<$ct>(), 8).unwrap();
        let mut allocator = $alloc::new();
        if (ptr as usize) % 8 == 0 {
            // Give up ownership without running `Vec`'s destructor.
            let _ = ManuallyDrop::new($buf);
            allocator.insert_ptr(ptr as *mut u8, $device);
        } else {
            let allocate_res = allocator.allocate(layout, $device).unwrap();
            ptr = allocate_res.get_ptr() as *mut $ct;
            // SAFETY: `$buf` holds `length` initialised elements; the destination
            // was just requested with room for `length` elements (assumes the
            // allocator honours `layout` - TODO confirm) and cannot overlap a
            // live `Vec` buffer.
            unsafe {
                std::ptr::copy_nonoverlapping($buf.as_ptr(), ptr, length);
            }
        }
        _Tensor {
            #[cfg(not(feature = "bound_check"))]
            data: Pointer::new(ptr),
            #[cfg(feature = "bound_check")]
            data: Pointer::new(ptr, length as i64),
            parent: None,
            layout: $ly,
            mem_layout: Arc::new(layout),
            backend: Backend::<Cpu>::new(ptr as u64, $device, true),
            phantom: PhantomData,
        }
    }};

    // `Vec<T>` -> one-dimensional tensor of shape `[len]` with stride `1`.
    (vec, $($t:ident),*) => {
        $(
            impl<const DEVICE: usize, A> From<Vec<$t>> for _Tensor<$t, Cpu, DEVICE, A>
            where
                A: Allocator,
                A::Output: AllocatorOutputRetrive,
            {
                fn from(mut data: Vec<$t>) -> Self {
                    let res_shape = Shape::from(vec![data.len() as i64]);
                    let ly = hpt_common::layout::layout::Layout::new(res_shape, vec![1]);
                    impl_type_num!(@from_buffer data, ly, $t, A, DEVICE)
                }
            }
            // NOTE(review): unlike the array conversions this wrapper is only
            // provided for the default allocator; confirm that is intentional.
            impl<const DEVICE: usize> From<Vec<$t>> for Tensor<$t, Cpu, DEVICE> {
                fn from(data: Vec<$t>) -> Self {
                    Tensor {
                        inner: Arc::new(data.into()),
                    }
                }
            }
        )*
    };

    // Owned nested array, several element types: handle the first type, then
    // recurse on the rest.
    (ndarray, $($generic:ident),*; $($vars:ident),*; $ct:ident, $($t:ident),*) => {
        impl_type_num!(ndarray, $($generic), *; $($vars), *; $ct);
        impl_type_num!(ndarray, $($generic), *; $($vars), *; $($t),*);
    };
    // Owned nested array, single element type.
    (ndarray, $($generic:ident),*; $($vars:ident),*; $ct:ident) => {
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $ct)> for _Tensor<$ct, Cpu, DEVICE, A>
        where
            A: Allocator,
            A::Output: AllocatorOutputRetrive,
        {
            fn from(data: repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                let shape = Shape::from(vec![$($generic as i64), *]);
                // Flatten the nested array into one contiguous buffer.
                let mut vec: Vec<$ct> = Vec::with_capacity(repeate_generic!(operations, *, $($generic), *));
                vec.extend(repeate_generic!(iterate, data; vec; $($vars), *));
                let strides = shape_to_strides(&shape);
                let ly = hpt_common::layout::layout::Layout::new(shape, strides);
                impl_type_num!(@from_buffer vec, ly, $ct, A, DEVICE)
            }
        }
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $ct)> for Tensor<$ct, Cpu, DEVICE, A>
        where
            A: Allocator,
            A::Output: AllocatorOutputRetrive,
        {
            fn from(data: repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                Tensor {
                    inner: Arc::new(data.into()),
                }
            }
        }
    };

    // Converting nested array, several target types: handle the first target,
    // then recurse on the rest. Delegating to the single-target arm guarantees
    // that the `_Tensor` impl required by the `Tensor` impl is always generated.
    (
        ndarray_source_target,
        $source:ident,
        $($generic:ident),*;
        $($vars:ident),*;
        $ct:ident,
        $($t:ident),*
    ) => {
        impl_type_num!(ndarray_source_target, $source, $($generic), *; $($vars), *; $ct);
        impl_type_num!(ndarray_source_target, $source, $($generic), *; $($vars), *; $($t),*);
    };
    // Converting nested array, single target type: every `$source` element is
    // converted into `$ct` with `Into` while flattening.
    (ndarray_source_target, $source:ident, $($generic:ident),*; $($vars:ident),*; $ct:ident) => {
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $source)> for _Tensor<$ct, Cpu, DEVICE, A>
        where
            A: Allocator,
            A::Output: AllocatorOutputRetrive,
        {
            fn from(data: repeate_generic!(nested_array_type, $($generic), *; $source)) -> Self {
                let shape = Shape::from(vec![$($generic as i64), *]);
                let mut vec: Vec<$ct> = Vec::with_capacity(repeate_generic!(operations, *, $($generic), *));
                // The closure's return type is spelled out so `extend` sees an
                // unambiguous item type.
                vec.extend(
                    repeate_generic!(iterate, data; vec; $($vars), *)
                        .map(|element| -> $ct { element.into() }),
                );
                let strides = shape_to_strides(&shape);
                let ly = hpt_common::layout::layout::Layout::new(shape, strides);
                impl_type_num!(@from_buffer vec, ly, $ct, A, DEVICE)
            }
        }
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<repeate_generic!(nested_array_type, $($generic), *; $source)> for Tensor<$ct, Cpu, DEVICE, A>
        where
            A: Allocator,
            A::Output: AllocatorOutputRetrive,
        {
            fn from(data: repeate_generic!(nested_array_type, $($generic), *; $source)) -> Self {
                Tensor {
                    inner: Arc::new(data.into()),
                }
            }
        }
    };

    // Borrowed nested array, several element types: handle the first type,
    // then recurse on the rest.
    (ndarray_ref, $($generic:ident),*; $($vars:ident),*; $ct:ident, $($t:ident),*) => {
        impl_type_num!(ndarray_ref, $($generic), *; $($vars), *; $ct);
        impl_type_num!(ndarray_ref, $($generic), *; $($vars), *; $($t),*);
    };
    // Borrowed nested array, single element type: elements are copied out of
    // the referenced array.
    (ndarray_ref, $($generic:ident),*; $($vars:ident),*; $ct:ident) => {
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<&repeate_generic!(nested_array_type, $($generic), *; $ct)> for _Tensor<$ct, Cpu, DEVICE, A>
        where
            A: Allocator,
            A::Output: AllocatorOutputRetrive,
        {
            fn from(data: &repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                let shape = Shape::from(vec![$($generic as i64), *]);
                let mut vec: Vec<$ct> = Vec::with_capacity(repeate_generic!(operations, *, $($generic), *));
                vec.extend(repeate_generic!(iterate, data; vec; $($vars), *).copied());
                let strides = shape_to_strides(&shape);
                let ly = hpt_common::layout::layout::Layout::new(shape, strides);
                impl_type_num!(@from_buffer vec, ly, $ct, A, DEVICE)
            }
        }
        // Generic over the allocator so the last type of a list gets the same
        // impl as every other type in that list.
        impl<$(const $generic: usize), *, const DEVICE: usize, A> From<&repeate_generic!(nested_array_type, $($generic), *; $ct)> for Tensor<$ct, Cpu, DEVICE, A>
        where
            A: Allocator,
            A::Output: AllocatorOutputRetrive,
        {
            fn from(data: &repeate_generic!(nested_array_type, $($generic), *; $ct)) -> Self {
                Tensor {
                    inner: Arc::new(data.into()),
                }
            }
        }
    };
}
354
// Token-level helpers for code that is generic over the number of array
// dimensions. The first token selects the operation:
//
//   * `const`             - `impl<const A: usize, const B: usize, ...>` header.
//   * `nested_array`      - nested array `[[T; C]; B]; A]`-style expansion.
//   * `nested_array_type` - same shape, used as a type.
//   * `operations`        - folds the operands with a binary operator token.
//   * `iterate`           - flattens a nested collection into one iterator.
//   * `mul`               - product of the operands.
macro_rules! repeate_generic {
    // NOTE(review): expands to a bare `impl<...>` header, which is not a
    // complete item by itself; no invocation of this arm is visible here.
    (const, $($t:ident),*) => {
        impl<$(const $t: usize), *>
    };

    // Outer dimension first: `nested_array, A, B; T` -> `[[T; B]; A]`.
    // The expansion must not end with `;`, otherwise it cannot be used where a
    // single type or expression is expected.
    (nested_array, $n:expr, $($t:expr),*; $data_type:ident) => {
        [repeate_generic!(nested_array, $($t), *; $data_type); $n]
    };
    (nested_array, $t:expr; $data_type:ident) => {
        [$data_type; $t]
    };

    // Outer dimension first: `nested_array_type, A, B; T` -> `[[T; B]; A]`.
    (nested_array_type, $n:expr, $($t:expr),*; $data_type:ident) => {
        [repeate_generic!(nested_array_type, $($t), *; $data_type); $n]
    };
    (nested_array_type, $t:expr; $data_type:ident) => {
        [$data_type; $t]
    };

    // `operations, op, a, b, c` -> `a op (b op c)`.
    (operations, $op:tt, $n:expr, $($t:expr),*) => {
        $n $op repeate_generic!(operations, $op, $($t), *)
    };
    (operations, $op:tt, $n:expr) => {
        $n
    };

    // `iterate, data; vec; v1, v2, ...` flattens `data` one nesting level per
    // listed identifier (each identifier names a closure argument) plus the
    // innermost level. The `$vec` identifier is accepted but never emitted.
    (iterate, $data:ident; $vec:ident; $n:ident, $($t:ident),*) => {
        $data.into_iter().flat_map(|$n| repeate_generic!(iterate, $vec; $n; ; $($t), *))
    };
    (iterate, $data:ident; $vec:ident; $n:ident) => {
        $data.into_iter().flat_map(|$n| repeate_generic!(iterate, $vec; $n; ;))
    };
    // Inner recursion steps: `$n` is the element produced by the enclosing level.
    (iterate, $vec:ident; $n:ident; ; $n2:ident, $($t:ident),*) => {
        $n.into_iter().flat_map(|$n2| repeate_generic!(iterate, $vec; $n2; ; $($t), *))
    };
    (iterate, $vec:ident; $n:ident; ; $n2:ident) => {
        $n.into_iter().flat_map(|$n2| repeate_generic!(iterate, $vec; $n2; ;))
    };
    (iterate, $vec:ident; $n:ident; ;) => {
        $n.into_iter()
    };
    // No closure variables: the data is already one-dimensional.
    (iterate, $data:ident; $vec:ident;) => {
        $data.into_iter()
    };

    // `mul, a, b, c` -> `a * (b * c)`.
    (mul, $n:expr, $($t:expr),*) => {
        $n * repeate_generic!(mul, $($t), *)
    };
    (mul, $n:expr) => {
        $n
    };
}
403
404from_scalar!(bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, bf16, f32, f64, Complex32, Complex64);
405impl_type_num!(
406 vec, bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64
407); impl_type_num!(ndarray, N; ; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
409impl_type_num!(ndarray_ref, N; ; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
410impl_type_num!(ndarray, N, M; i; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
411impl_type_num!(ndarray_ref, N, M; i; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
412impl_type_num!(ndarray, N, M, O; i, j; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
413impl_type_num!(ndarray_ref, N, M, O; i, j; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
414impl_type_num!(ndarray, N, M, O, P; i, j, k; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
415impl_type_num!(ndarray_ref, N, M, O, P; i, j, k; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
416impl_type_num!(ndarray, N, M, O, P, Q; i, j, k, l; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
417impl_type_num!(ndarray_ref, N, M, O, P, Q; i, j, k, l; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
418impl_type_num!(ndarray, N, M, O, P, Q, R; i, j, k, l, m; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
419impl_type_num!(ndarray_ref, N, M, O, P, Q, R; i, j, k, l, m; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
420impl_type_num!(ndarray, N, M, O, P, Q, R, S; i, j, k, l, m, n; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
421impl_type_num!(ndarray_ref, N, M, O, P, Q, R, S; i, j, k, l, m, n; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
422impl_type_num!(ndarray, N, M, O, P, Q, R, S, T; i, j, k, l, m, n, o; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
423impl_type_num!(ndarray_ref, N, M, O, P, Q, R, S, T; i, j, k, l, m, n, o; bool, i8, u8, i16, u16, i32, u32, i64, u64, f16, f32, f64, Complex32, Complex64);
424impl_type_num!(ndarray_source_target, f32, N; ; Complex32);
425impl_type_num!(ndarray_source_target, f64, N; ; Complex64);
426impl_type_num!(ndarray_source_target, f32, N, M; i; Complex32);
427impl_type_num!(ndarray_source_target, f64, N, M; i; Complex64);
428impl_type_num!(ndarray_source_target, f32, N, M, O; i, j; Complex32);
429impl_type_num!(ndarray_source_target, f64, N, M, O; i, j; Complex64);
430impl_type_num!(ndarray_source_target, f32, N, M, O, P; i, j, k; Complex32);
431impl_type_num!(ndarray_source_target, f64, N, M, O, P; i, j, k; Complex64);
432impl_type_num!(ndarray_source_target, f32, N, M, O, P, Q; i, j, k, l; Complex32);
433impl_type_num!(ndarray_source_target, f64, N, M, O, P, Q; i, j, k, l; Complex64);
434impl_type_num!(ndarray_source_target, f32, N, M, O, P, Q, R; i, j, k, l, m; Complex32);
435impl_type_num!(ndarray_source_target, f64, N, M, O, P, Q, R; i, j, k, l, m; Complex64);
436impl_type_num!(ndarray_source_target, f32, N, M, O, P, Q, R, S; i, j, k, l, m, n; Complex32);
437impl_type_num!(ndarray_source_target, f64, N, M, O, P, Q, R, S; i, j, k, l, m, n; Complex64);
438impl_type_num!(ndarray_source_target, f32, N, M, O, P, Q, R, S, T; i, j, k, l, m, n, o; Complex32);
439impl_type_num!(ndarray_source_target, f64, N, M, O, P, Q, R, S, T; i, j, k, l, m, n, o; Complex64);
440
441impl<T, const DEVICE: usize, Al> Tensor<T, Cpu, DEVICE, Al>
442where
443 Al: Allocator,
444{
445 pub fn new<A>(data: A) -> Self
447 where
448 A: Into<Tensor<T, Cpu, DEVICE, Al>>,
449 {
450 data.into()
451 }
452}
453
454impl<T, const DEVICE: usize, Al> DiffTensor<T, Cpu, DEVICE, Al>
455where
456 Al: Allocator,
457{
458 pub fn new<A>(data: A) -> Self
460 where
461 A: Into<Tensor<T, Cpu, DEVICE, Al>>,
462 {
463 let ret = data.into();
464 DiffTensor {
465 inner: ret,
466 grad: Rc::new(RefCell::new(None)),
467 out_degree: Rc::new(RefCell::new(0)),
468 backward: Rc::new(RefCell::new(move |_| Ok(true))),
469 }
470 }
471}