flashlight_tensor/
tensor.rs

/// An N-dimensional tensor backed by a single flat buffer.
///
/// `sizes` holds the extent of every dimension ordered as `[..., z, y, x]`;
/// the last entry (`x`) is the one that varies fastest inside `data`.
#[derive(Clone)]
pub struct Tensor<T> {
    /// Flat element storage.
    data: Vec<T>,
    /// Extent of each dimension, ordered `[..., z, y, x]`.
    sizes: Vec<u32>,
}

impl<T: Default + Clone> Tensor<T> {
    /// Returns the number of elements a tensor of shape `sizes` holds,
    /// or `None` when that number does not fit in `usize`.
    ///
    /// The product is accumulated in `usize` with overflow checks, so a
    /// shape can never silently wrap around to a smaller element count.
    fn checked_element_count(sizes: &[u32]) -> Option<usize> {
        // A zero extent makes the tensor empty no matter how large the
        // other extents are, so it must not be reported as an overflow.
        if sizes.contains(&0) {
            return Some(0);
        }
        sizes
            .iter()
            .try_fold(1usize, |count, &extent| count.checked_mul(extent as usize))
    }

    /// Creates a new tensor of the given sizes with every element set to
    /// `T::default()`.
    ///
    /// # Panics
    /// Panics if the total element count does not fit in `usize`.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// //a =
    /// //[0.0, 0.0]
    /// //[0.0, 0.0]
    /// let a: Tensor<f32> = Tensor::new(&[2, 2]);
    ///
    /// assert_eq!(a.get_data(), &vec![0.0, 0.0, 0.0, 0.0]);
    /// ```
    pub fn new(_sizes: &[u32]) -> Tensor<T> {
        Self::fill(T::default(), _sizes)
    }

    /// Creates a new tensor from existing data, or returns `None` when the
    /// number of elements in `_data` does not match the product of `_sizes`.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// //a =
    /// //[1.0, 2.0]
    /// //[3.0, 4.0]
    /// let a: Tensor<f32> = Tensor::from_data(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    /// assert_eq!(a.get_data(), &vec![1.0, 2.0, 3.0, 4.0]);
    /// ```
    pub fn from_data(_data: &[T], _sizes: &[u32]) -> Option<Self> {
        // An overflowing shape can never match a real buffer, so `?` rejects it.
        if Self::checked_element_count(_sizes)? != _data.len() {
            return None;
        }

        Some(Self {
            data: _data.to_vec(),
            sizes: _sizes.to_vec(),
        })
    }

    /// Creates a new tensor of the given sizes with every element set to a
    /// clone of `fill_data`.
    ///
    /// # Panics
    /// Panics if the total element count does not fit in `usize`.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// //a =
    /// //[1.0, 1.0]
    /// //[1.0, 1.0]
    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
    ///
    /// assert_eq!(a.get_data(), &vec![1.0, 1.0, 1.0, 1.0]);
    /// ```
    pub fn fill(fill_data: T, _sizes: &[u32]) -> Self {
        let element_count = Self::checked_element_count(_sizes)
            .expect("tensor element count overflows usize");

        Self {
            data: vec![fill_data; element_count],
            sizes: _sizes.to_vec(),
        }
    }

    /// Returns a reference to the flat data buffer of the tensor.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
    ///
    /// //b = &{1.0, 1.0, 1.0, 1.0}
    /// let b = a.get_data();
    ///
    /// assert_eq!(b, &vec![1.0, 1.0, 1.0, 1.0]);
    /// ```
    pub fn get_data(&self) -> &Vec<T> {
        &self.data
    }

    /// Returns a reference to the sizes of the tensor.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
    ///
    /// //b = &{2, 2}
    /// let b = a.get_sizes();
    ///
    /// assert_eq!(b, &vec![2, 2]);
    /// ```
    pub fn get_sizes(&self) -> &Vec<u32> {
        &self.sizes
    }

    /// Returns a new tensor holding the data of `self` followed by the data
    /// of `tens2`, with `sizes[0] = self.sizes[0] + tens2.sizes[0]`.
    ///
    /// Returns `None` when the trailing sizes (`sizes[1..]`) of both tensors
    /// differ, when either tensor has no dimensions at all, or when the
    /// combined first dimension does not fit in `u32`.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
    /// let b: Tensor<f32> = Tensor::fill(2.0, &[2, 2]);
    ///
    /// //c.data = {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0}
    /// //c.sizes = {4, 2}
    /// let c: Tensor<f32> = a.append(&b).unwrap();
    ///
    /// assert_eq!(c.get_data(), &vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]);
    /// assert_eq!(c.get_sizes(), &vec![4, 2]);
    /// ```
    pub fn append(&self, tens2: &Tensor<T>) -> Option<Self> {
        // `split_first` yields `None` for a tensor without dimensions, which
        // has no leading axis to append along.
        let (own_first, own_rest) = self.sizes.split_first()?;
        let (other_first, other_rest) = tens2.sizes.split_first()?;

        if own_rest != other_rest {
            return None;
        }

        let merged_first = own_first.checked_add(*other_first)?;

        let mut data: Vec<T> = Vec::with_capacity(self.data.len() + tens2.data.len());
        data.extend_from_slice(&self.data);
        data.extend_from_slice(&tens2.data);

        let mut sizes: Vec<u32> = Vec::with_capacity(self.sizes.len());
        sizes.push(merged_first);
        sizes.extend_from_slice(own_rest);

        Some(Self { data, sizes })
    }

    /// Counts the elements stored in the tensor.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
    ///
    /// //count = 4
    /// let count = a.count_data();
    ///
    /// assert_eq!(count, 4);
    /// ```
    pub fn count_data(&self) -> usize {
        self.data.len()
    }

    /// Changes the sizes of the tensor if the product of `new_sizes` equals
    /// the number of stored elements; otherwise the tensor is left unchanged.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[4]);
    ///
    /// a.set_size(&[1, 4]);
    ///
    /// assert_eq!(a.get_sizes(), &vec![1, 4]);
    /// ```
    pub fn set_size(&mut self, new_sizes: &[u32]) {
        // A shape whose element count overflows compares as `None` and is rejected.
        if Self::checked_element_count(new_sizes) == Some(self.data.len()) {
            self.sizes = new_sizes.to_vec();
        }
    }

    /// Replaces the data of the tensor if `new_data` has the same length as
    /// the current data; otherwise the tensor is left unchanged.
    ///
    /// # Example
    /// ```
    /// use flashlight_tensor::prelude::*;
    /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[4]);
    ///
    /// a.set_data(&[2.0, 3.0, 4.0, 5.0]);
    ///
    /// assert_eq!(a.get_data(), &vec![2.0, 3.0, 4.0, 5.0]);
    /// ```
    pub fn set_data(&mut self, new_data: &[T]) {
        if new_data.len() == self.data.len() {
            // Lengths match, so the existing allocation can be reused.
            self.data.clone_from_slice(new_data);
        }
    }
}
207impl<T> Tensor<T>{
208    /// returns an element on position
209    ///
210    /// # Example
211    /// ```
212    /// use flashlight_tensor::prelude::*;
213    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
214    ///
215    /// //b = 1.0
216    /// let b = a.value(&[0, 0]).unwrap();
217    ///
218    /// assert_eq!(b, &1.0);
219    /// ```
220    pub fn value(&self, pos: &[u32]) -> Option<&T>{
221        let self_dimensions = self.sizes.len();
222        let selector_dimensions = pos.len();
223        if self_dimensions - selector_dimensions != 0{
224            return None;
225        }
226        
227        for i in 0..pos.len(){
228            if pos[i] >= *self.sizes.get(i).unwrap(){
229                return None;
230            }
231        }
232        let mut index = 0;
233        let mut stride = 1;
234        for i in (0..self.sizes.len()).rev() {
235            index += pos[i] * stride;
236            stride *= self.sizes[i];
237        }
238
239        Some(&self.data[index as usize])
240    }
241    /// changes an element on position
242    ///
243    /// # Example
244    /// ```
245    /// use flashlight_tensor::prelude::*;
246    /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
247    ///
248    /// //a =
249    /// //[5.0, 1.0]
250    /// //[1.0, 1.0]
251    /// a.set(5.0, &[0, 0]);
252    ///
253    /// assert_eq!(a.get_data(), &vec!{5.0, 1.0, 1.0, 1.0});
254    /// ```
255    pub fn set(&mut self, value: T, pos: &[u32]){
256        let self_dimensions = self.sizes.len();
257        let selector_dimensions = pos.len();
258        if self_dimensions - selector_dimensions != 0{
259            return;
260        }
261        
262        for i in 0..pos.len(){
263            if pos[i] >= *self.sizes.get(i).unwrap(){
264                return;
265            }
266        }
267        let mut index = 0;
268        let mut stride = 1;
269        for i in (0..self.sizes.len()).rev() {
270            index += pos[i] * stride;
271            stride *= self.sizes[i];
272        }
273
274        self.data[index as usize] = value;
275    }
276
277    /// change linear id into global id based on tensor shape
278    ///
279    /// # Example
280    /// ```
281    /// use flashlight_tensor::prelude::*;
282    /// let mut a: Tensor<f32> = Tensor::new(&[5, 1, 2]);
283    ///
284    /// let global_id = a.idx_to_global(3);
285    ///
286    /// assert_eq!(global_id, vec!{1, 0, 1});
287    /// ```
288    pub fn idx_to_global(&self, idx: u32) -> Vec<u32>{
289        idx_to_global(idx, &self.sizes)
290    }
291}
292
/// Converts a linear index into a per-dimension position for `shape`.
///
/// `shape` is ordered `[..., z, y, x]`, with the last dimension varying
/// fastest. The returned vector has one coordinate per dimension in the
/// same order.
///
/// An empty vector is returned when `idx` is not a valid index for `shape`
/// (that is, `idx` is greater than or equal to the number of elements),
/// when any dimension is zero, or when `shape` itself is empty.
///
/// # Example
/// ```
/// use flashlight_tensor::prelude::*;
/// let a: Tensor<f32> = Tensor::new(&[5, 1, 2]);
///
/// // `Tensor::idx_to_global` forwards to this function with the tensor's shape.
/// let global_id = a.idx_to_global(3);
///
/// assert_eq!(global_id, vec![1, 0, 1]);
/// ```
pub fn idx_to_global(idx: u32, shape: &[u32]) -> Vec<u32> {
    let mut remaining = idx;
    let mut coords: Vec<u32> = Vec::with_capacity(shape.len());

    // Peel coordinates off starting at the fastest-varying (last) dimension.
    // This needs no product of the shape, so it cannot overflow.
    for &extent in shape.iter().rev() {
        if extent == 0 {
            // A tensor with a zero-sized dimension has no valid positions.
            return Vec::new();
        }
        coords.push(remaining % extent);
        remaining /= extent;
    }

    // Anything left over means `idx` was past the last element.
    if remaining != 0 {
        return Vec::new();
    }

    coords.reverse();
    coords
}