1use std::{cell::RefCell, rc::Rc};
2
3use crate::{
4 tensor::{DiffTensor, Tensor},
5 tensor_base::_Tensor,
6 BoolVector, Cpu,
7};
8use hpt_allocator::traits::{Allocator, AllocatorOutputRetrive};
9use hpt_common::{error::base::TensorError, shape::shape::Shape};
10use hpt_traits::{CommonBounds, TensorCreator};
11use hpt_types::{dtype::TypeCommon, into_scalar::Cast, type_promote::NormalOut};
12
13impl<T: CommonBounds, const DEVICE: usize, Al> TensorCreator<T> for Tensor<T, Cpu, DEVICE, Al>
14where
15 Al: Allocator,
16 Al::Output: AllocatorOutputRetrive,
17{
18 type Output = Tensor<T, Cpu, DEVICE, Al>;
19
20 fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError> {
21 Ok(_Tensor::<T, Cpu, DEVICE, Al>::empty(shape)?.into())
22 }
23
24 fn zeros<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError> {
25 Ok(_Tensor::<T, Cpu, DEVICE, Al>::zeros(shape)?.into())
26 }
27
28 fn ones<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>
29 where
30 u8: Cast<T>,
31 {
32 Ok(_Tensor::<T, Cpu, DEVICE, Al>::ones(shape)?.into())
33 }
34
35 fn empty_like(&self) -> Result<Self::Output, TensorError> {
36 Ok(_Tensor::empty_like(self.inner.as_ref())?.into())
37 }
38
39 fn zeros_like(&self) -> Result<Self::Output, TensorError> {
40 Ok(_Tensor::zeros_like(self.inner.as_ref())?.into())
41 }
42
43 fn ones_like(&self) -> Result<Self::Output, TensorError>
44 where
45 u8: Cast<T>,
46 {
47 Ok(_Tensor::ones_like(self.inner.as_ref())?.into())
48 }
49
50 fn full<S: Into<Shape>>(val: T, shape: S) -> Result<Self::Output, TensorError> {
51 Ok(_Tensor::<T, Cpu, DEVICE, Al>::full(val, shape)?.into())
52 }
53
54 fn full_like(&self, val: T) -> Result<Self::Output, TensorError> {
55 Ok(_Tensor::full_like(self.inner.as_ref(), val)?.into())
56 }
57
58 fn arange<U>(start: U, end: U) -> Result<Self::Output, TensorError>
59 where
60 usize: Cast<T>,
61 U: Cast<i64> + Cast<T> + Copy,
62 {
63 Ok(_Tensor::<T, Cpu, DEVICE, Al>::arange(start, end)?.into())
64 }
65
66 fn arange_step(start: T, end: T, step: T) -> Result<Self::Output, TensorError>
67 where
68 T: Cast<f64> + Cast<usize>,
69 usize: Cast<T>,
70 {
71 Ok(_Tensor::<T, Cpu, DEVICE, Al>::arange_step(start, end, step)?.into())
72 }
73
74 fn eye(n: usize, m: usize, k: usize) -> Result<Self::Output, TensorError> {
75 Ok(_Tensor::<T, Cpu, DEVICE, Al>::eye(n, m, k)?.into())
76 }
77
78 fn linspace<U>(
79 start: U,
80 end: U,
81 num: usize,
82 include_end: bool,
83 ) -> Result<Self::Output, TensorError>
84 where
85 U: Cast<f64> + Cast<T> + Copy,
86 usize: Cast<T>,
87 f64: Cast<T>,
88 {
89 Ok(_Tensor::<T, Cpu, DEVICE, Al>::linspace(start, end, num, include_end)?.into())
90 }
91
92 fn logspace(
93 start: T,
94 end: T,
95 num: usize,
96 include_end: bool,
97 base: T,
98 ) -> Result<Self::Output, TensorError>
99 where
100 T: Cast<f64> + num::Float + NormalOut<T, Output = T>,
101 usize: Cast<T>,
102 f64: Cast<T>,
103 {
104 Ok(_Tensor::<T, Cpu, DEVICE, Al>::logspace(start, end, num, include_end, base)?.into())
105 }
106
107 fn geomspace(start: T, end: T, n: usize, include_end: bool) -> Result<Self::Output, TensorError>
108 where
109 f64: Cast<T>,
110 usize: Cast<T>,
111 T: Cast<f64>,
112 {
113 Ok(_Tensor::<T, Cpu, DEVICE, Al>::geomspace(start, end, n, include_end)?.into())
114 }
115
116 fn tri(n: usize, m: usize, k: i64, low_triangle: bool) -> Result<Self::Output, TensorError>
117 where
118 u8: Cast<T>,
119 {
120 Ok(_Tensor::<T, Cpu, DEVICE, Al>::tri(n, m, k, low_triangle)?.into())
121 }
122
123 fn tril(&self, k: i64) -> Result<Self::Output, TensorError>
124 where
125 T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
126 T::Vec: NormalOut<BoolVector, Output = T::Vec>,
127 {
128 Ok(_Tensor::tril(self.inner.as_ref(), k)?.into())
129 }
130
131 fn triu(&self, k: i64) -> Result<Self::Output, TensorError>
132 where
133 T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
134 T::Vec: NormalOut<BoolVector, Output = T::Vec>,
135 {
136 Ok(_Tensor::triu(self.inner.as_ref(), k)?.into())
137 }
138
139 fn identity(n: usize) -> Result<Self::Output, TensorError>
140 where
141 u8: Cast<T>,
142 {
143 Ok(_Tensor::<T, Cpu, DEVICE, Al>::identity(n)?.into())
144 }
145}
146
147impl<T: CommonBounds, const DEVICE: usize, Al> TensorCreator<T> for DiffTensor<T, Cpu, DEVICE, Al>
148where
149 Al: Allocator,
150 Al::Output: AllocatorOutputRetrive,
151{
152 type Output = DiffTensor<T, Cpu, DEVICE, Al>;
153
154 fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError> {
155 let ret = Tensor::<T, Cpu, DEVICE, Al>::empty(shape)?;
156 Ok(DiffTensor {
157 inner: ret,
158 grad: Rc::new(RefCell::new(None)),
159 out_degree: Rc::new(RefCell::new(0)),
160 backward: Rc::new(RefCell::new(move |_| Ok(true))),
161 })
162 }
163
164 fn zeros<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError> {
165 let ret = Tensor::<T, Cpu, DEVICE, Al>::zeros(shape)?;
166 Ok(DiffTensor {
167 inner: ret,
168 grad: Rc::new(RefCell::new(None)),
169 out_degree: Rc::new(RefCell::new(0)),
170 backward: Rc::new(RefCell::new(move |_| Ok(true))),
171 })
172 }
173
174 fn ones<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>
175 where
176 u8: Cast<T>,
177 {
178 let ret = Tensor::<T, Cpu, DEVICE, Al>::ones(shape)?;
179 Ok(DiffTensor {
180 inner: ret,
181 grad: Rc::new(RefCell::new(None)),
182 out_degree: Rc::new(RefCell::new(0)),
183 backward: Rc::new(RefCell::new(move |_| Ok(true))),
184 })
185 }
186
187 fn empty_like(&self) -> Result<Self::Output, TensorError> {
188 let ret = self.inner.empty_like()?;
189 Ok(DiffTensor {
190 inner: ret,
191 grad: Rc::new(RefCell::new(None)),
192 out_degree: Rc::new(RefCell::new(0)),
193 backward: Rc::new(RefCell::new(move |_| Ok(true))),
194 })
195 }
196
197 fn zeros_like(&self) -> Result<Self::Output, TensorError> {
198 let ret = self.inner.zeros_like()?;
199 Ok(DiffTensor {
200 inner: ret,
201 grad: Rc::new(RefCell::new(None)),
202 out_degree: Rc::new(RefCell::new(0)),
203 backward: Rc::new(RefCell::new(move |_| Ok(true))),
204 })
205 }
206
207 fn ones_like(&self) -> Result<Self::Output, TensorError>
208 where
209 u8: Cast<T>,
210 {
211 let ret = self.inner.ones_like()?;
212 Ok(DiffTensor {
213 inner: ret,
214 grad: Rc::new(RefCell::new(None)),
215 out_degree: Rc::new(RefCell::new(0)),
216 backward: Rc::new(RefCell::new(move |_| Ok(true))),
217 })
218 }
219
220 fn full<S: Into<Shape>>(val: T, shape: S) -> Result<Self::Output, TensorError> {
221 let ret = Tensor::full(val, shape)?;
222 Ok(DiffTensor {
223 inner: ret,
224 grad: Rc::new(RefCell::new(None)),
225 out_degree: Rc::new(RefCell::new(0)),
226 backward: Rc::new(RefCell::new(move |_| Ok(true))),
227 })
228 }
229
230 fn full_like(&self, val: T) -> Result<Self::Output, TensorError> {
231 let ret = self.inner.full_like(val)?;
232 Ok(DiffTensor {
233 inner: ret,
234 grad: Rc::new(RefCell::new(None)),
235 out_degree: Rc::new(RefCell::new(0)),
236 backward: Rc::new(RefCell::new(move |_| Ok(true))),
237 })
238 }
239
240 fn arange<U>(start: U, end: U) -> Result<Self::Output, TensorError>
241 where
242 usize: Cast<T>,
243 U: Cast<i64> + Cast<T> + Copy,
244 {
245 let ret = Tensor::arange(start, end)?;
246 Ok(DiffTensor {
247 inner: ret,
248 grad: Rc::new(RefCell::new(None)),
249 out_degree: Rc::new(RefCell::new(0)),
250 backward: Rc::new(RefCell::new(move |_| Ok(true))),
251 })
252 }
253
254 fn arange_step(start: T, end: T, step: T) -> Result<Self::Output, TensorError>
255 where
256 T: Cast<f64> + Cast<usize>,
257 usize: Cast<T>,
258 {
259 let ret = Tensor::arange_step(start, end, step)?;
260 Ok(DiffTensor {
261 inner: ret,
262 grad: Rc::new(RefCell::new(None)),
263 out_degree: Rc::new(RefCell::new(0)),
264 backward: Rc::new(RefCell::new(move |_| Ok(true))),
265 })
266 }
267
268 fn eye(n: usize, m: usize, k: usize) -> Result<Self::Output, TensorError> {
269 let ret = Tensor::<T, Cpu, DEVICE, Al>::eye(n, m, k)?;
270 Ok(DiffTensor {
271 inner: ret,
272 grad: Rc::new(RefCell::new(None)),
273 out_degree: Rc::new(RefCell::new(0)),
274 backward: Rc::new(RefCell::new(move |_| Ok(true))),
275 })
276 }
277
278 fn linspace<U>(
279 start: U,
280 end: U,
281 num: usize,
282 include_end: bool,
283 ) -> Result<Self::Output, TensorError>
284 where
285 U: Cast<f64> + Cast<T> + Copy,
286 usize: Cast<T>,
287 f64: Cast<T>,
288 {
289 let ret = Tensor::linspace(start, end, num, include_end)?;
290 Ok(DiffTensor {
291 inner: ret,
292 grad: Rc::new(RefCell::new(None)),
293 out_degree: Rc::new(RefCell::new(0)),
294 backward: Rc::new(RefCell::new(move |_| Ok(true))),
295 })
296 }
297
298 fn logspace(
299 start: T,
300 end: T,
301 num: usize,
302 include_end: bool,
303 base: T,
304 ) -> Result<Self::Output, TensorError>
305 where
306 T: Cast<f64> + num::Float + NormalOut<T, Output = T>,
307 usize: Cast<T>,
308 f64: Cast<T>,
309 {
310 let ret = Tensor::logspace(start, end, num, include_end, base)?;
311 Ok(DiffTensor {
312 inner: ret,
313 grad: Rc::new(RefCell::new(None)),
314 out_degree: Rc::new(RefCell::new(0)),
315 backward: Rc::new(RefCell::new(move |_| Ok(true))),
316 })
317 }
318
319 fn geomspace(start: T, end: T, n: usize, include_end: bool) -> Result<Self::Output, TensorError>
320 where
321 f64: Cast<T>,
322 usize: Cast<T>,
323 T: Cast<f64>,
324 {
325 let ret = Tensor::geomspace(start, end, n, include_end)?;
326 Ok(DiffTensor {
327 inner: ret,
328 grad: Rc::new(RefCell::new(None)),
329 out_degree: Rc::new(RefCell::new(0)),
330 backward: Rc::new(RefCell::new(move |_| Ok(true))),
331 })
332 }
333
334 fn tri(n: usize, m: usize, k: i64, low_triangle: bool) -> Result<Self::Output, TensorError>
335 where
336 u8: Cast<T>,
337 {
338 let ret = Tensor::tri(n, m, k, low_triangle)?;
339 Ok(DiffTensor {
340 inner: ret,
341 grad: Rc::new(RefCell::new(None)),
342 out_degree: Rc::new(RefCell::new(0)),
343 backward: Rc::new(RefCell::new(move |_| Ok(true))),
344 })
345 }
346
347 fn tril(&self, k: i64) -> Result<Self::Output, TensorError>
348 where
349 T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
350 T::Vec: NormalOut<BoolVector, Output = T::Vec>,
351 {
352 let ret = self.inner.tril(k)?;
353 Ok(DiffTensor {
354 inner: ret,
355 grad: Rc::new(RefCell::new(None)),
356 out_degree: Rc::new(RefCell::new(0)),
357 backward: Rc::new(RefCell::new(move |_| unimplemented!())),
358 })
359 }
360
361 fn triu(&self, k: i64) -> Result<Self::Output, TensorError>
362 where
363 T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
364 T::Vec: NormalOut<BoolVector, Output = T::Vec>,
365 {
366 let ret = self.inner.triu(k)?;
367 Ok(DiffTensor {
368 inner: ret,
369 grad: Rc::new(RefCell::new(None)),
370 out_degree: Rc::new(RefCell::new(0)),
371 backward: Rc::new(RefCell::new(move |_| unimplemented!())),
372 })
373 }
374
375 fn identity(n: usize) -> Result<Self::Output, TensorError>
376 where
377 u8: Cast<T>,
378 {
379 let ret = Tensor::identity(n)?;
380 Ok(DiffTensor {
381 inner: ret,
382 grad: Rc::new(RefCell::new(None)),
383 out_degree: Rc::new(RefCell::new(0)),
384 backward: Rc::new(RefCell::new(move |_| Ok(true))),
385 })
386 }
387}