concision_traits/tensor/
ndtensor.rs

1/*
2    Appellation: ndtensor <module>
3    Created At: 2025.11.26:14:27:51
4    Contrib: @FL03
5*/
6use ndarray::{ArrayBase, Data, DataMut, Dimension, OwnedRepr, RawData, RawDataMut};
7use num_traits::Float;
8
9pub trait TensorBase<S, D, A> {
10    type Cont<_S, _D, _A>
11    where
12        _D: Dimension,
13        _S: RawData<Elem = _A>;
14
15    fn rank(&self) -> usize;
16
17    fn shape(&self) -> &[usize];
18
19    fn size(&self) -> usize;
20}
21
22pub trait NdTensor<S, D, A = <S as RawData>::Elem>:
23    TensorBase<S, D, A, Cont<S, D, A> = ArrayBase<S, D, A>>
24where
25    D: Dimension,
26    S: RawData<Elem = A>,
27{
28    fn as_ptr(&self) -> *const A;
29
30    fn as_mut_ptr(&mut self) -> *mut A
31    where
32        S: RawDataMut;
33
34    fn apply<F, B>(&self, f: F) -> Self::Cont<OwnedRepr<B>, D, B>
35    where
36        F: FnMut(A) -> B,
37        A: Clone,
38        S: Data;
39
40    fn powi(&self, n: i32) -> Self::Cont<OwnedRepr<A>, D, A>
41    where
42        A: Float,
43        S: DataMut,
44    {
45        self.apply(|x| x.powi(n))
46    }
47
48    fn exp(&self) -> Self::Cont<OwnedRepr<A>, D, A>
49    where
50        A: Float,
51        S: DataMut,
52    {
53        self.apply(|x| x.exp())
54    }
55
56    fn log(&self) -> Self::Cont<OwnedRepr<A>, D, A>
57    where
58        A: Float,
59        S: DataMut,
60    {
61        self.apply(|x| x.ln())
62    }
63
64    fn ln(&self) -> Self::Cont<OwnedRepr<A>, D, A>
65    where
66        A: Float,
67        S: DataMut,
68    {
69        self.apply(|x| x.ln())
70    }
71
72    fn cos(&self) -> Self::Cont<OwnedRepr<A>, D, A>
73    where
74        A: Float,
75        S: DataMut,
76    {
77        self.apply(|x| x.cos())
78    }
79
80    fn cosh(&self) -> Self::Cont<OwnedRepr<A>, D, A>
81    where
82        A: Float,
83        S: DataMut,
84    {
85        self.apply(|x| x.cosh())
86    }
87
88    fn sin(&self) -> Self::Cont<OwnedRepr<A>, D, A>
89    where
90        A: Float,
91        S: DataMut,
92    {
93        self.apply(|x| x.sin())
94    }
95
96    fn sinh(&self) -> Self::Cont<OwnedRepr<A>, D, A>
97    where
98        A: Float,
99        S: DataMut,
100    {
101        self.apply(|x| x.sinh())
102    }
103
104    fn tan(&self) -> Self::Cont<OwnedRepr<A>, D, A>
105    where
106        A: Float,
107        S: DataMut,
108    {
109        self.apply(|x| x.tan())
110    }
111
112    fn tanh(&self) -> Self::Cont<OwnedRepr<A>, D, A>
113    where
114        A: Float,
115        S: DataMut,
116    {
117        self.apply(|x| x.tanh())
118    }
119}
120
121pub trait NdGradient<S, D, A = <S as RawData>::Elem>: NdTensor<S, D, A>
122where
123    D: Dimension,
124    S: RawData<Elem = A>,
125{
126    type Delta<_S, _D, _A>: NdTensor<_S, _D, _A>
127    where
128        _D: Dimension,
129        _S: RawData<Elem = _A>;
130
131    fn grad(&self, rhs: &Self::Delta<S, D, A>) -> Self::Delta<S, D, A>;
132}
133
134/*
135 ************* Implementations *************
136*/
137
138impl<A, S, D> TensorBase<S, D, A> for ArrayBase<S, D, A>
139where
140    D: Dimension,
141    S: RawData<Elem = A>,
142{
143    type Cont<_S, _D, _A>
144        = ArrayBase<_S, _D, _A>
145    where
146        _D: Dimension,
147        _S: RawData<Elem = _A>;
148
149    fn rank(&self) -> usize {
150        self.ndim()
151    }
152
153    fn shape(&self) -> &[usize] {
154        self.shape()
155    }
156
157    fn size(&self) -> usize {
158        self.len()
159    }
160}
161
162impl<A, S, D> NdTensor<S, D, A> for ArrayBase<S, D, A>
163where
164    D: Dimension,
165    S: RawData<Elem = A>,
166{
167    fn as_ptr(&self) -> *const A {
168        self.as_ptr()
169    }
170
171    fn as_mut_ptr(&mut self) -> *mut A
172    where
173        S: RawDataMut,
174    {
175        self.as_mut_ptr()
176    }
177
178    fn apply<F, B>(&self, f: F) -> Self::Cont<OwnedRepr<B>, D, B>
179    where
180        A: Clone,
181        F: FnMut(A) -> B,
182        S: Data,
183    {
184        self.mapv(f)
185    }
186}