1use ndarray::{ArrayBase, Data, DataMut, Dimension, OwnedRepr, RawData, RawDataMut};
7use num_traits::Float;
8
9pub trait TensorBase<S, D, A> {
10 type Cont<_S, _D, _A>
11 where
12 _D: Dimension,
13 _S: RawData<Elem = _A>;
14
15 fn rank(&self) -> usize;
16
17 fn shape(&self) -> &[usize];
18
19 fn size(&self) -> usize;
20}
21
22pub trait NdTensor<S, D, A = <S as RawData>::Elem>:
23 TensorBase<S, D, A, Cont<S, D, A> = ArrayBase<S, D, A>>
24where
25 D: Dimension,
26 S: RawData<Elem = A>,
27{
28 fn as_ptr(&self) -> *const A;
29
30 fn as_mut_ptr(&mut self) -> *mut A
31 where
32 S: RawDataMut;
33
34 fn apply<F, B>(&self, f: F) -> Self::Cont<OwnedRepr<B>, D, B>
35 where
36 F: FnMut(A) -> B,
37 A: Clone,
38 S: Data;
39
40 fn powi(&self, n: i32) -> Self::Cont<OwnedRepr<A>, D, A>
41 where
42 A: Float,
43 S: DataMut,
44 {
45 self.apply(|x| x.powi(n))
46 }
47
48 fn exp(&self) -> Self::Cont<OwnedRepr<A>, D, A>
49 where
50 A: Float,
51 S: DataMut,
52 {
53 self.apply(|x| x.exp())
54 }
55
56 fn log(&self) -> Self::Cont<OwnedRepr<A>, D, A>
57 where
58 A: Float,
59 S: DataMut,
60 {
61 self.apply(|x| x.ln())
62 }
63
64 fn ln(&self) -> Self::Cont<OwnedRepr<A>, D, A>
65 where
66 A: Float,
67 S: DataMut,
68 {
69 self.apply(|x| x.ln())
70 }
71
72 fn cos(&self) -> Self::Cont<OwnedRepr<A>, D, A>
73 where
74 A: Float,
75 S: DataMut,
76 {
77 self.apply(|x| x.cos())
78 }
79
80 fn cosh(&self) -> Self::Cont<OwnedRepr<A>, D, A>
81 where
82 A: Float,
83 S: DataMut,
84 {
85 self.apply(|x| x.cosh())
86 }
87
88 fn sin(&self) -> Self::Cont<OwnedRepr<A>, D, A>
89 where
90 A: Float,
91 S: DataMut,
92 {
93 self.apply(|x| x.sin())
94 }
95
96 fn sinh(&self) -> Self::Cont<OwnedRepr<A>, D, A>
97 where
98 A: Float,
99 S: DataMut,
100 {
101 self.apply(|x| x.sinh())
102 }
103
104 fn tan(&self) -> Self::Cont<OwnedRepr<A>, D, A>
105 where
106 A: Float,
107 S: DataMut,
108 {
109 self.apply(|x| x.tan())
110 }
111
112 fn tanh(&self) -> Self::Cont<OwnedRepr<A>, D, A>
113 where
114 A: Float,
115 S: DataMut,
116 {
117 self.apply(|x| x.tanh())
118 }
119}
120
121pub trait NdGradient<S, D, A = <S as RawData>::Elem>: NdTensor<S, D, A>
122where
123 D: Dimension,
124 S: RawData<Elem = A>,
125{
126 type Delta<_S, _D, _A>: NdTensor<_S, _D, _A>
127 where
128 _D: Dimension,
129 _S: RawData<Elem = _A>;
130
131 fn grad(&self, rhs: &Self::Delta<S, D, A>) -> Self::Delta<S, D, A>;
132}
133
/// [`TensorBase`] for every ndarray [`ArrayBase`], forwarding to ndarray's accessors.
///
/// `Cont` is `ArrayBase` itself. Containers derived from an array therefore stay in
/// the ndarray type family while varying storage, dimension or element type.
impl<A, S, D> TensorBase<S, D, A> for ArrayBase<S, D, A>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    type Cont<_S, _D, _A>
        = ArrayBase<_S, _D, _A>
    where
        _D: Dimension,
        _S: RawData<Elem = _A>;

    /// Number of axes, as reported by ndarray's `ndim`.
    fn rank(&self) -> usize {
        self.ndim()
    }

    /// Length along each axis, as reported by ndarray's `shape`.
    fn shape(&self) -> &[usize] {
        // This call has the same name as the trait method being defined. Rust's
        // method resolution prefers an inherent `ArrayBase::shape` over the trait
        // method, so this is intended to reach ndarray.
        // NOTE(review): if the pinned ndarray version exposes `shape` only through a
        // `Deref` target (not inherently on `ArrayBase`), this resolves to
        // `TensorBase::shape` and recurses forever — confirm.
        self.shape()
    }

    /// Total number of elements, as reported by ndarray's `len`.
    fn size(&self) -> usize {
        self.len()
    }
}
161
/// [`NdTensor`] for every ndarray [`ArrayBase`].
///
/// Pointer access forwards to ndarray, and `apply` is ndarray's `mapv`. The
/// element-wise math helpers come from the trait's default methods.
impl<A, S, D> NdTensor<S, D, A> for ArrayBase<S, D, A>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    fn as_ptr(&self) -> *const A {
        // Relies on an inherent `ArrayBase::as_ptr` taking precedence over this
        // same-named trait method.
        // NOTE(review): if ndarray only provides it via `Deref`, this recurses —
        // confirm against the pinned ndarray version.
        self.as_ptr()
    }

    fn as_mut_ptr(&mut self) -> *mut A
    where
        S: RawDataMut,
    {
        // Same inherent-precedence assumption as `as_ptr` above.
        self.as_mut_ptr()
    }

    /// Maps `f` over the elements by value via ndarray's `mapv`, producing an
    /// owned array.
    fn apply<F, B>(&self, f: F) -> Self::Cont<OwnedRepr<B>, D, B>
    where
        A: Clone,
        F: FnMut(A) -> B,
        S: Data,
    {
        self.mapv(f)
    }
}