flashlight_tensor/tensor.rs
/// An n-dimensional tensor backed by a single flat buffer.
///
/// Elements are stored row-major: the last axis is contiguous in memory and
/// every earlier axis strides over all axes after it. The shape is recorded
/// outermost axis first, i.e. `[..., z, y, x]`.
#[derive(Clone)]
pub struct Tensor<T> {
    /// Flat element storage in row-major order.
    data: Vec<T>,
    /// Extent of every axis, outermost first: `[..., z, y, x]`.
    sizes: Vec<u32>,
}
9
10impl<T: Default + Clone> Tensor<T>{
11 /// Creates a new tensor with sizes
12 /// and default values of each element
13 ///
14 /// # Example
15 /// ```
16 /// use flashlight_tensor::prelude::*;
17 /// //a =
18 /// //[0.0, 0.0]
19 /// //[0.0, 0.0]
20 /// let a: Tensor<f32> = Tensor::new(&[2, 2]);
21 ///
22 /// assert_eq!(a.get_data(), &vec!{0.0, 0.0, 0.0, 0.0});
23 /// ```
24 pub fn new(_sizes: &[u32]) -> Tensor<T>{
25 let mut total_size: u32 = 1;
26 for i in 0.._sizes.len(){
27 total_size *= _sizes[i];
28 }
29
30 Self{
31 data: vec![T::default(); total_size as usize],
32 sizes: _sizes.to_vec(),
33 }
34 }
35
36 /// Creates a new tensor from data
37 /// with certain size, or None
38 /// if data does not fit in sizes
39 ///
40 /// # Example
41 /// ```
42 /// use flashlight_tensor::prelude::*;
43 /// //a =
44 /// //[1.0, 2.0]
45 /// //[3.0, 4.0]
46 /// let a: Tensor<f32> = Tensor::from_data(&vec!{1.0, 2.0, 3.0, 4.0}, &[2, 2]).unwrap();
47 /// assert_eq!(a.get_data(), &vec!{1.0, 2.0, 3.0, 4.0});
48 /// ```
49 pub fn from_data(_data: &[T], _sizes: &[u32]) -> Option<Self>{
50 if _sizes.iter().product::<u32>() as usize != _data.len(){
51 return None;
52 }
53
54 Some(Self{
55 data: _data.to_vec(),
56 sizes: _sizes.to_vec(),
57 })
58 }
59
60 /// Creates a new tensor filled
61 /// with one element
62 /// with certain size
63 ///
64 /// # Example
65 /// ```
66 /// use flashlight_tensor::prelude::*;
67 /// //a =
68 /// //[1.0, 1.0]
69 /// //[1.0, 1.0]
70 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
71 ///
72 /// assert_eq!(a.get_data(), &vec!{1.0, 1.0, 1.0, 1.0});
73 /// ```
74 pub fn fill(fill_data: T, _sizes: &[u32]) -> Self{
75 let full_size: u32 = _sizes.iter().product();
76
77 Self{
78 data: vec![fill_data; full_size as usize],
79 sizes: _sizes.to_vec(),
80 }
81 }
82
83 /// Returns reference to data in tensor
84 ///
85 /// # Example
86 /// ```
87 /// use flashlight_tensor::prelude::*;
88 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
89 ///
90 /// //b = &{1.0, 1.0, 1.0, 1.0}
91 /// let b = a.get_data();
92 ///
93 /// assert_eq!(a.get_data(), &vec!{1.0, 1.0, 1.0, 1.0});
94 /// ```
95 pub fn get_data(&self) -> &Vec<T>{
96 return &self.data;
97 }
98
99 /// Returns reference to sizes in tensor
100 ///
101 /// # Example
102 /// ```
103 /// use flashlight_tensor::prelude::*;
104 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
105 ///
106 /// //b = &{2, 2}
107 /// let b = a.get_sizes();
108 ///
109 /// assert_eq!(a.get_sizes(), &vec!{2, 2});
110 /// ```
111 pub fn get_sizes(&self) -> &Vec<u32>{
112 return &self.sizes;
113 }
114 /// returns new tensor with data of first tensor + data of second tensor
115 /// with size[0] = tensor1.size[0] + tensor2.size[0]
116 /// only when tensor1.size[1..] == tensor2.size[1..]
117 ///
118 /// # Example
119 /// ```
120 /// use flashlight_tensor::prelude::*;
121 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
122 /// let b: Tensor<f32> = Tensor::fill(2.0, &[2, 2]);
123 ///
124 /// //c.data = {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0}
125 /// //c.sizes = {4, 2}
126 /// let c: Tensor<f32> = a.append(&b).unwrap();
127 ///
128 /// assert_eq!(c.get_data(), &vec!{1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0});
129 /// assert_eq!(c.get_sizes(), &vec!{4, 2});
130 /// ```
131 pub fn append(&self, tens2: &Tensor<T>) -> Option<Self>{
132 if (self.sizes.len() != 1 || tens2.sizes.len() != 1) && self.get_sizes()[1..].to_vec() != tens2.get_sizes()[1..].to_vec(){
133 return None;
134 }
135
136 let mut return_data: Vec<T> = self.get_data().clone();
137 let mut append_data: Vec<T> = tens2.get_data().clone();
138
139 return_data.append(&mut append_data);
140
141 let mut return_sizes = self.get_sizes().clone();
142 return_sizes[0] += tens2.get_sizes()[0];
143
144 Some(Self{
145 data: return_data,
146 sizes: return_sizes,
147 })
148 }
149 /// counts elements in tensor
150 ///
151 /// # Example
152 /// ```
153 /// use flashlight_tensor::prelude::*;
154 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
155 ///
156 /// //count = 4
157 /// let count = a.count_data();
158 ///
159 /// assert_eq!(a.count_data(), 4);
160 /// ```
161 pub fn count_data(&self) -> usize{
162 self.get_data().len()
163 }
164
165 /// Change the size of tensor if the full size of new_sizes is equal to data.len() stored in
166 /// tensor.
167 ///
168 /// # Example
169 /// ```
170 /// use flashlight_tensor::prelude::*;
171 /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[4]);
172 ///
173 /// a.set_size(&[1, 4]);
174 ///
175 /// assert_eq!(a.get_sizes(), &vec!{1, 4});
176 /// ```
177 pub fn set_size(&mut self, new_sizes: &[u32]){
178
179 let sizes_prod: u32 = new_sizes.iter().product();
180
181 if(sizes_prod as usize != self.data.len()){
182 return;
183 }
184
185 self.sizes = new_sizes.to_vec();
186 }
187
188 /// Change the data of tensor if the new data has length equal to current data length
189 ///
190 /// # Example
191 /// ```
192 /// use flashlight_tensor::prelude::*;
193 /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[4]);
194 ///
195 /// a.set_data(&[2.0, 3.0, 4.0, 5.0]);
196 ///
197 /// assert_eq!(a.get_data(), &vec!{2.0, 3.0, 4.0, 5.0});
198 /// ```
199 pub fn set_data(&mut self, new_data: &[T]){
200 if new_data.len() != self.data.len(){
201 return;
202 }
203
204 self.data = new_data.to_vec();
205 }
206}
207impl<T> Tensor<T>{
208 /// returns an element on position
209 ///
210 /// # Example
211 /// ```
212 /// use flashlight_tensor::prelude::*;
213 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
214 ///
215 /// //b = 1.0
216 /// let b = a.value(&[0, 0]).unwrap();
217 ///
218 /// assert_eq!(b, &1.0);
219 /// ```
220 pub fn value(&self, pos: &[u32]) -> Option<&T>{
221 let self_dimensions = self.sizes.len();
222 let selector_dimensions = pos.len();
223 if self_dimensions - selector_dimensions != 0{
224 return None;
225 }
226
227 for i in 0..pos.len(){
228 if pos[i] >= *self.sizes.get(i).unwrap(){
229 return None;
230 }
231 }
232 let mut index = 0;
233 let mut stride = 1;
234 for i in (0..self.sizes.len()).rev() {
235 index += pos[i] * stride;
236 stride *= self.sizes[i];
237 }
238
239 Some(&self.data[index as usize])
240 }
241 /// changes an element on position
242 ///
243 /// # Example
244 /// ```
245 /// use flashlight_tensor::prelude::*;
246 /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
247 ///
248 /// //a =
249 /// //[5.0, 1.0]
250 /// //[1.0, 1.0]
251 /// a.set(5.0, &[0, 0]);
252 ///
253 /// assert_eq!(a.get_data(), &vec!{5.0, 1.0, 1.0, 1.0});
254 /// ```
255 pub fn set(&mut self, value: T, pos: &[u32]){
256 let self_dimensions = self.sizes.len();
257 let selector_dimensions = pos.len();
258 if self_dimensions - selector_dimensions != 0{
259 return;
260 }
261
262 for i in 0..pos.len(){
263 if pos[i] >= *self.sizes.get(i).unwrap(){
264 return;
265 }
266 }
267 let mut index = 0;
268 let mut stride = 1;
269 for i in (0..self.sizes.len()).rev() {
270 index += pos[i] * stride;
271 stride *= self.sizes[i];
272 }
273
274 self.data[index as usize] = value;
275 }
276
277 /// change linear id into global id based on tensor shape
278 ///
279 /// # Example
280 /// ```
281 /// use flashlight_tensor::prelude::*;
282 /// let mut a: Tensor<f32> = Tensor::new(&[5, 1, 2]);
283 ///
284 /// let global_id = a.idx_to_global(3);
285 ///
286 /// assert_eq!(global_id, vec!{1, 0, 1});
287 /// ```
288 pub fn idx_to_global(&self, idx: u32) -> Vec<u32>{
289 idx_to_global(idx, &self.sizes)
290 }
291}
292
/// Converts a linear (flat, row-major) index into per-axis coordinates for
/// `shape`, outermost axis first.
///
/// Returns an empty `Vec` when `idx` does not address an element of `shape`,
/// i.e. when `idx >= shape.iter().product()`. This includes every shape that
/// contains a zero-sized dimension. An empty `shape` (scalar) also yields an
/// empty `Vec`, because a scalar has no coordinates.
///
/// # Example
/// ```
/// use flashlight_tensor::prelude::*;
/// let mut a: Tensor<f32> = Tensor::new(&[5, 1, 2]);
///
/// let global_id = a.idx_to_global(3);
///
/// assert_eq!(global_id, vec![1, 0, 1]);
/// ```
pub fn idx_to_global(idx: u32, shape: &[u32]) -> Vec<u32> {
    let total: u32 = shape.iter().product();
    // `>=` (not `>`): `idx == total` is already one past the last element.
    // This also guarantees every dimension is non-zero below, so the
    // `%` / `/` can never divide by zero.
    if idx >= total {
        return Vec::new();
    }

    let mut remaining = idx;
    let mut coords = vec![0; shape.len()];
    // Peel off coordinates from the innermost (fastest-varying) axis outward.
    for (coord, &dim) in coords.iter_mut().zip(shape.iter()).rev() {
        *coord = remaining % dim;
        remaining /= dim;
    }

    coords
}