flashlight_tensor/tensor.rs
1use rand::{Rng};
2
/// An n-dimensional tensor stored as one flat, row-major buffer.
///
/// Axes are listed outermost first as `[..., z, y, x]`, so the last entry of
/// `shape` is the axis that varies fastest inside `data`.
#[derive(Clone)]
pub struct Tensor<T> {
    /// Flat element storage; its length is expected to equal the product of `shape`.
    data: Vec<T>,
    /// Length of every axis, ordered `[..., z, y, x]`.
    shape: Vec<u32>,
}
11
12impl<T: Default + Clone> Tensor<T>{
13 /// Creates a new tensor with shape
14 /// and default values of each element
15 ///
16 /// # Example
17 /// ```
18 /// use flashlight_tensor::prelude::*;
19 /// //a =
20 /// //[0.0, 0.0]
21 /// //[0.0, 0.0]
22 /// let a: Tensor<f32> = Tensor::new(&[2, 2]);
23 ///
24 /// assert_eq!(a.get_data(), &vec!{0.0, 0.0, 0.0, 0.0});
25 /// ```
26 pub fn new(_shape: &[u32]) -> Tensor<T>{
27 let mut total_size: u32 = 1;
28 for i in 0.._shape.len(){
29 total_size *= _shape[i];
30 }
31
32 Self{
33 data: vec![T::default(); total_size as usize],
34 shape: _shape.to_vec(),
35 }
36 }
37
38 /// Creates a new tensor from data
39 /// with certain size, or None
40 /// if data does not fit in shape
41 ///
42 /// # Example
43 /// ```
44 /// use flashlight_tensor::prelude::*;
45 /// //a =
46 /// //[1.0, 2.0]
47 /// //[3.0, 4.0]
48 /// let a: Tensor<f32> = Tensor::from_data(&vec!{1.0, 2.0, 3.0, 4.0}, &[2, 2]).unwrap();
49 /// assert_eq!(a.get_data(), &vec!{1.0, 2.0, 3.0, 4.0});
50 /// ```
51 pub fn from_data(_data: &[T], _shape: &[u32]) -> Option<Self>{
52 if _shape.iter().product::<u32>() as usize != _data.len(){
53 return None;
54 }
55
56 Some(Self{
57 data: _data.to_vec(),
58 shape: _shape.to_vec(),
59 })
60 }
61
62 /// Creates a new tensor filled
63 /// with one element
64 /// with certain size
65 ///
66 /// # Example
67 /// ```
68 /// use flashlight_tensor::prelude::*;
69 /// //a =
70 /// //[1.0, 1.0]
71 /// //[1.0, 1.0]
72 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
73 ///
74 /// assert_eq!(a.get_data(), &vec!{1.0, 1.0, 1.0, 1.0});
75 /// ```
76 pub fn fill(fill_data: T, _shape: &[u32]) -> Self{
77 let full_size: u32 = _shape.iter().product();
78
79 Self{
80 data: vec![fill_data; full_size as usize],
81 shape: _shape.to_vec(),
82 }
83 }
84
85 /// Returns reference to data in tensor
86 ///
87 /// # Example
88 /// ```
89 /// use flashlight_tensor::prelude::*;
90 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
91 ///
92 /// //b = &{1.0, 1.0, 1.0, 1.0}
93 /// let b = a.get_data();
94 ///
95 /// assert_eq!(a.get_data(), &vec!{1.0, 1.0, 1.0, 1.0});
96 /// ```
97 pub fn get_data(&self) -> &Vec<T>{
98 return &self.data;
99 }
100
101 /// Returns reference to shape in tensor
102 ///
103 /// # Example
104 /// ```
105 /// use flashlight_tensor::prelude::*;
106 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
107 ///
108 /// //b = &{2, 2}
109 /// let b = a.get_shape();
110 ///
111 /// assert_eq!(a.get_shape(), &vec!{2, 2});
112 /// ```
113 pub fn get_shape(&self) -> &Vec<u32>{
114 return &self.shape;
115 }
116 /// returns new tensor with data of first tensor + data of second tensor
117 /// with size[0] = tensor1.size[0] + tensor2.size[0]
118 /// only when tensor1.size[1..] == tensor2.size[1..]
119 ///
120 /// # Example
121 /// ```
122 /// use flashlight_tensor::prelude::*;
123 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
124 /// let b: Tensor<f32> = Tensor::fill(2.0, &[2, 2]);
125 ///
126 /// //c.data = {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0}
127 /// //c.shape = {4, 2}
128 /// let c: Tensor<f32> = a.append(&b).unwrap();
129 ///
130 /// assert_eq!(c.get_data(), &vec!{1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0});
131 /// assert_eq!(c.get_shape(), &vec!{4, 2});
132 /// ```
133 pub fn append(&self, tens2: &Tensor<T>) -> Option<Self>{
134 if (self.shape.len() != 1 || tens2.shape.len() != 1) && self.get_shape()[1..].to_vec() != tens2.get_shape()[1..].to_vec(){
135 return None;
136 }
137
138 let mut return_data: Vec<T> = self.get_data().clone();
139 let mut append_data: Vec<T> = tens2.get_data().clone();
140
141 return_data.append(&mut append_data);
142
143 let mut return_shape = self.get_shape().clone();
144 return_shape[0] += tens2.get_shape()[0];
145
146 Some(Self{
147 data: return_data,
148 shape: return_shape,
149 })
150 }
151 /// counts elements in tensor
152 ///
153 /// # Example
154 /// ```
155 /// use flashlight_tensor::prelude::*;
156 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
157 ///
158 /// //count = 4
159 /// let count = a.count_data();
160 ///
161 /// assert_eq!(a.count_data(), 4);
162 /// ```
163 pub fn count_data(&self) -> usize{
164 self.get_data().len()
165 }
166
167 /// Change the size of tensor if the full size of new_shape is equal to data.len() stored in
168 /// tensor.
169 ///
170 /// # Example
171 /// ```
172 /// use flashlight_tensor::prelude::*;
173 /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[4]);
174 ///
175 /// a.set_shape(&[1, 4]);
176 ///
177 /// assert_eq!(a.get_shape(), &vec!{1, 4});
178 /// ```
179 pub fn set_shape(&mut self, new_shape: &[u32]){
180
181 let shape_prod: u32 = new_shape.iter().product();
182
183 if(shape_prod as usize != self.data.len()){
184 return;
185 }
186
187 self.shape = new_shape.to_vec();
188 }
189
190 /// Change the data of tensor if the new data has length equal to current data length
191 ///
192 /// # Example
193 /// ```
194 /// use flashlight_tensor::prelude::*;
195 /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[4]);
196 ///
197 /// a.set_data(&[2.0, 3.0, 4.0, 5.0]);
198 ///
199 /// assert_eq!(a.get_data(), &vec!{2.0, 3.0, 4.0, 5.0});
200 /// ```
201 pub fn set_data(&mut self, new_data: &[T]){
202 if new_data.len() != self.data.len(){
203 return;
204 }
205
206 self.data = new_data.to_vec();
207 }
208}
209impl<T> Tensor<T>{
210 /// returns an element on position
211 ///
212 /// # Example
213 /// ```
214 /// use flashlight_tensor::prelude::*;
215 /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
216 ///
217 /// //b = 1.0
218 /// let b = a.value(&[0, 0]).unwrap();
219 ///
220 /// assert_eq!(b, &1.0);
221 /// ```
222 pub fn value(&self, pos: &[u32]) -> Option<&T>{
223 let self_dimensions = self.shape.len();
224 let selector_dimensions = pos.len();
225 if self_dimensions - selector_dimensions != 0{
226 return None;
227 }
228
229 for i in 0..pos.len(){
230 if pos[i] >= *self.shape.get(i).unwrap(){
231 return None;
232 }
233 }
234 let mut index = 0;
235 let mut stride = 1;
236 for i in (0..self.shape.len()).rev() {
237 index += pos[i] * stride;
238 stride *= self.shape[i];
239 }
240
241 Some(&self.data[index as usize])
242 }
243 /// changes an element on position
244 ///
245 /// # Example
246 /// ```
247 /// use flashlight_tensor::prelude::*;
248 /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
249 ///
250 /// //a =
251 /// //[5.0, 1.0]
252 /// //[1.0, 1.0]
253 /// a.set(5.0, &[0, 0]);
254 ///
255 /// assert_eq!(a.get_data(), &vec!{5.0, 1.0, 1.0, 1.0});
256 /// ```
257 pub fn set(&mut self, value: T, pos: &[u32]){
258 let self_dimensions = self.shape.len();
259 let selector_dimensions = pos.len();
260 if self_dimensions - selector_dimensions != 0{
261 return;
262 }
263
264 for i in 0..pos.len(){
265 if pos[i] >= *self.shape.get(i).unwrap(){
266 return;
267 }
268 }
269 let mut index = 0;
270 let mut stride = 1;
271 for i in (0..self.shape.len()).rev() {
272 index += pos[i] * stride;
273 stride *= self.shape[i];
274 }
275
276 self.data[index as usize] = value;
277 }
278
279 /// change linear id into global id based on tensor shape
280 ///
281 /// # Example
282 /// ```
283 /// use flashlight_tensor::prelude::*;
284 /// let mut a: Tensor<f32> = Tensor::new(&[5, 1, 2]);
285 ///
286 /// let global_id = a.idx_to_global(3);
287 ///
288 /// assert_eq!(global_id, vec!{1, 0, 1});
289 /// ```
290 pub fn idx_to_global(&self, idx: u32) -> Vec<u32>{
291 idx_to_global(idx, &self.shape)
292 }
293}
294
295impl Tensor<f32> {
296
297 /// Creates a new tensor with random data data
298 /// with certain size
299 ///
300 /// # Example
301 /// ```
302 /// use flashlight_tensor::prelude::*;
303 ///
304 /// let a: Tensor<f32> = Tensor::rand(1.0, &[2, 2]);
305 /// ```
306 pub fn rand(rand_range: f32, _shape: &[u32]) -> Self{
307 let full_size: u32 = _shape.iter().product();
308 let mut data: Vec<f32> = Vec::with_capacity(full_size as usize);
309
310 let mut rng = rand::rng();
311
312 for i in 0..full_size{
313 data.push(rng.random_range(-rand_range..rand_range));
314 }
315
316 Self{
317 data,
318 shape: _shape.to_vec(),
319 }
320 }
321}
322
/// Converts a linear (flat, row-major) index into per-axis coordinates for `shape`.
///
/// The last axis varies fastest, matching the `[..., z, y, x]` layout used by
/// `Tensor`.
///
/// Returns an empty `Vec` when `idx` is out of range, i.e. when `idx` is greater
/// than or equal to the product of `shape`. This includes every shape that has
/// a zero-length axis. A zero-dimensional shape (`&[]`) also yields an empty `Vec`.
///
/// # Example
/// ```
/// use flashlight_tensor::prelude::*;
/// let a: Tensor<f32> = Tensor::new(&[5, 1, 2]);
///
/// // The `Tensor::idx_to_global` method forwards to this function.
/// let global_id = a.idx_to_global(3);
///
/// assert_eq!(global_id, vec!{1, 0, 1});
/// ```
pub fn idx_to_global(idx: u32, shape: &[u32]) -> Vec<u32> {
    let total: u32 = shape.iter().product();

    // `idx == total` is already one past the last element. Rejecting it here
    // also rules out the division by zero below: any zero-length axis makes
    // `total == 0`, and every `idx` is then out of range.
    if idx >= total {
        return Vec::new();
    }

    let mut remainder = idx;
    let mut stride = total;
    shape
        .iter()
        .map(|&dim| {
            // `stride` becomes the number of elements spanned by one step along `dim`.
            stride /= dim;
            let coord = remainder / stride;
            remainder %= stride;
            coord
        })
        .collect()
}