// flashlight_tensor/tensor.rs

1use rand::{Rng};
2
/// Core n-dimensional tensor type.
///
/// Elements live in one flat buffer (`data`). `shape` gives the extent of
/// every axis from outermost to innermost, i.e. `[..., z, y, x]`; the last
/// axis is the one that varies fastest when walking `data`.
#[derive(Clone)]
pub struct Tensor<T> {
    /// Flattened element storage.
    data: Vec<T>,
    /// Axis extents, ordered `[..., z, y, x]`.
    shape: Vec<u32>,
}
11
12impl<T: Default + Clone> Tensor<T>{
13    /// Creates a new tensor with shape
14    /// and default values of each element
15    ///
16    /// # Example
17    /// ```
18    /// use flashlight_tensor::prelude::*;
19    /// //a =
20    /// //[0.0, 0.0]
21    /// //[0.0, 0.0]
22    /// let a: Tensor<f32> = Tensor::new(&[2, 2]);
23    ///
24    /// assert_eq!(a.get_data(), &vec!{0.0, 0.0, 0.0, 0.0});
25    /// ```
26    pub fn new(_shape: &[u32]) -> Tensor<T>{
27        let mut total_size: u32 = 1;
28        for i in 0.._shape.len(){
29            total_size *= _shape[i];
30        }
31        
32        Self{
33            data: vec![T::default(); total_size as usize],
34            shape: _shape.to_vec(),
35        }
36    }
37
38    /// Creates a new tensor from data
39    /// with certain size, or None
40    /// if data does not fit in shape
41    ///
42    /// # Example
43    /// ```
44    /// use flashlight_tensor::prelude::*;
45    /// //a =
46    /// //[1.0, 2.0]
47    /// //[3.0, 4.0]
48    /// let a: Tensor<f32> = Tensor::from_data(&vec!{1.0, 2.0, 3.0, 4.0}, &[2, 2]).unwrap();
49    /// assert_eq!(a.get_data(), &vec!{1.0, 2.0, 3.0, 4.0});
50    /// ```
51    pub fn from_data(_data: &[T], _shape: &[u32]) -> Option<Self>{
52        if _shape.iter().product::<u32>() as usize != _data.len(){
53            return None;
54        }
55
56        Some(Self{
57            data: _data.to_vec(),
58            shape: _shape.to_vec(),
59        })
60    }
61    
62    /// Creates a new tensor filled
63    /// with one element
64    /// with certain size
65    ///
66    /// # Example
67    /// ```
68    /// use flashlight_tensor::prelude::*;
69    /// //a = 
70    /// //[1.0, 1.0]
71    /// //[1.0, 1.0]
72    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
73    ///
74    /// assert_eq!(a.get_data(), &vec!{1.0, 1.0, 1.0, 1.0});
75    /// ```
76    pub fn fill(fill_data: T, _shape: &[u32]) -> Self{
77        let full_size: u32 = _shape.iter().product();
78        
79        Self{
80            data: vec![fill_data; full_size as usize],
81            shape: _shape.to_vec(),
82        }
83    }
84
85    /// Returns reference to data in tensor
86    /// 
87    /// # Example
88    /// ```
89    /// use flashlight_tensor::prelude::*;
90    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
91    ///
92    /// //b = &{1.0, 1.0, 1.0, 1.0}
93    /// let b = a.get_data();
94    ///
95    /// assert_eq!(a.get_data(), &vec!{1.0, 1.0, 1.0, 1.0});
96    /// ```
97    pub fn get_data(&self) -> &Vec<T>{
98        return &self.data;
99    }
100
101    /// Returns reference to shape in tensor
102    /// 
103    /// # Example
104    /// ```
105    /// use flashlight_tensor::prelude::*;
106    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
107    ///
108    /// //b = &{2, 2}
109    /// let b = a.get_shape();
110    ///
111    /// assert_eq!(a.get_shape(), &vec!{2, 2});
112    /// ```
113    pub fn get_shape(&self) -> &Vec<u32>{
114        return &self.shape;
115    }
116    /// returns new tensor with data of first tensor + data of second tensor
117    /// with size[0] = tensor1.size[0] + tensor2.size[0]
118    /// only when tensor1.size[1..] == tensor2.size[1..]
119    ///
120    /// # Example
121    /// ```
122    /// use flashlight_tensor::prelude::*;
123    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
124    /// let b: Tensor<f32> = Tensor::fill(2.0, &[2, 2]);
125    ///
126    /// //c.data = {1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0}
127    /// //c.shape = {4, 2}
128    /// let c: Tensor<f32> = a.append(&b).unwrap();
129    ///
130    /// assert_eq!(c.get_data(), &vec!{1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0});
131    /// assert_eq!(c.get_shape(), &vec!{4, 2});
132    /// ```
133    pub fn append(&self, tens2: &Tensor<T>) -> Option<Self>{
134        if (self.shape.len() != 1 || tens2.shape.len() != 1) && self.get_shape()[1..].to_vec() != tens2.get_shape()[1..].to_vec(){
135            return None;
136        }
137
138        let mut return_data: Vec<T> = self.get_data().clone();
139        let mut append_data: Vec<T> = tens2.get_data().clone();
140        
141        return_data.append(&mut append_data);
142
143        let mut return_shape = self.get_shape().clone();
144        return_shape[0] += tens2.get_shape()[0];
145
146        Some(Self{
147            data: return_data,
148            shape: return_shape,
149        })
150    }
151    /// counts elements in tensor
152    ///
153    /// # Example
154    /// ```
155    /// use flashlight_tensor::prelude::*;
156    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
157    ///
158    /// //count = 4
159    /// let count = a.count_data();
160    ///
161    /// assert_eq!(a.count_data(), 4);
162    /// ```
163    pub fn count_data(&self) -> usize{
164        self.get_data().len()
165    }
166    
167    /// Change the size of tensor if the full size of new_shape is equal to data.len() stored in
168    /// tensor.
169    ///
170    /// # Example
171    /// ```
172    /// use flashlight_tensor::prelude::*;
173    /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[4]);
174    ///
175    /// a.set_shape(&[1, 4]);
176    ///
177    /// assert_eq!(a.get_shape(), &vec!{1, 4});
178    /// ```
179    pub fn set_shape(&mut self, new_shape: &[u32]){
180        
181        let shape_prod: u32 = new_shape.iter().product();
182
183        if(shape_prod as usize != self.data.len()){
184            return;
185        }
186
187        self.shape = new_shape.to_vec();
188    }
189
190    /// Change the data of tensor if the new data has length equal to current data length
191    ///
192    /// # Example
193    /// ```
194    /// use flashlight_tensor::prelude::*;
195    /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[4]);
196    ///
197    /// a.set_data(&[2.0, 3.0, 4.0, 5.0]);
198    ///
199    /// assert_eq!(a.get_data(), &vec!{2.0, 3.0, 4.0, 5.0});
200    /// ```
201    pub fn set_data(&mut self, new_data: &[T]){
202        if new_data.len() != self.data.len(){
203            return;
204        }
205
206        self.data = new_data.to_vec();
207    }
208}
209impl<T> Tensor<T>{
210    /// returns an element on position
211    ///
212    /// # Example
213    /// ```
214    /// use flashlight_tensor::prelude::*;
215    /// let a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
216    ///
217    /// //b = 1.0
218    /// let b = a.value(&[0, 0]).unwrap();
219    ///
220    /// assert_eq!(b, &1.0);
221    /// ```
222    pub fn value(&self, pos: &[u32]) -> Option<&T>{
223        let self_dimensions = self.shape.len();
224        let selector_dimensions = pos.len();
225        if self_dimensions - selector_dimensions != 0{
226            return None;
227        }
228        
229        for i in 0..pos.len(){
230            if pos[i] >= *self.shape.get(i).unwrap(){
231                return None;
232            }
233        }
234        let mut index = 0;
235        let mut stride = 1;
236        for i in (0..self.shape.len()).rev() {
237            index += pos[i] * stride;
238            stride *= self.shape[i];
239        }
240
241        Some(&self.data[index as usize])
242    }
243    /// changes an element on position
244    ///
245    /// # Example
246    /// ```
247    /// use flashlight_tensor::prelude::*;
248    /// let mut a: Tensor<f32> = Tensor::fill(1.0, &[2, 2]);
249    ///
250    /// //a =
251    /// //[5.0, 1.0]
252    /// //[1.0, 1.0]
253    /// a.set(5.0, &[0, 0]);
254    ///
255    /// assert_eq!(a.get_data(), &vec!{5.0, 1.0, 1.0, 1.0});
256    /// ```
257    pub fn set(&mut self, value: T, pos: &[u32]){
258        let self_dimensions = self.shape.len();
259        let selector_dimensions = pos.len();
260        if self_dimensions - selector_dimensions != 0{
261            return;
262        }
263        
264        for i in 0..pos.len(){
265            if pos[i] >= *self.shape.get(i).unwrap(){
266                return;
267            }
268        }
269        let mut index = 0;
270        let mut stride = 1;
271        for i in (0..self.shape.len()).rev() {
272            index += pos[i] * stride;
273            stride *= self.shape[i];
274        }
275
276        self.data[index as usize] = value;
277    }
278
279    /// change linear id into global id based on tensor shape
280    ///
281    /// # Example
282    /// ```
283    /// use flashlight_tensor::prelude::*;
284    /// let mut a: Tensor<f32> = Tensor::new(&[5, 1, 2]);
285    ///
286    /// let global_id = a.idx_to_global(3);
287    ///
288    /// assert_eq!(global_id, vec!{1, 0, 1});
289    /// ```
290    pub fn idx_to_global(&self, idx: u32) -> Vec<u32>{
291        idx_to_global(idx, &self.shape)
292    }
293}
294
295impl Tensor<f32> {
296
297    /// Creates a new tensor with random data data
298    /// with certain size
299    ///
300    /// # Example
301    /// ```
302    /// use flashlight_tensor::prelude::*;
303    ///
304    /// let a: Tensor<f32> = Tensor::rand(1.0, &[2, 2]);
305    /// ```
306    pub fn rand(rand_range: f32, _shape: &[u32]) -> Self{
307        let full_size: u32 = _shape.iter().product();
308        let mut data: Vec<f32> = Vec::with_capacity(full_size as usize);
309        
310        let mut rng = rand::rng();
311
312        for i in 0..full_size{
313            data.push(rng.random_range(-rand_range..rand_range));
314        }
315
316        Self{
317            data,
318            shape: _shape.to_vec(),
319        }
320    }
321}
322
/// Converts a linear (flat, row-major) index into per-axis coordinates for
/// `shape`.
///
/// Coordinates are returned outermost axis first, in the same order as
/// `shape`. Returns an empty `Vec` in these cases:
/// - `idx` is out of range (`idx >=` the product of the dimensions);
/// - any dimension is zero, so the shape holds no elements.
///
/// A 0-dimensional shape (`&[]`) also yields an empty `Vec` for its only
/// valid index, `0`.
///
/// # Example
/// ```
/// use flashlight_tensor::prelude::*;
/// let mut a: Tensor<f32> = Tensor::new(&[5, 1, 2]);
///
/// let global_id = a.idx_to_global(3);
///
/// assert_eq!(global_id, vec!{1, 0, 1});
/// assert!(a.idx_to_global(10).is_empty());
/// ```
pub fn idx_to_global(idx: u32, shape: &[u32]) -> Vec<u32> {
    // A zero-sized axis means there is no valid index; this also avoids `% 0`.
    if shape.contains(&0) {
        return Vec::new();
    }

    // Peel coordinates off from the innermost axis outward. This needs no
    // product of the shape, so large shapes cannot overflow.
    let mut remaining = idx;
    let mut output_vec = vec![0u32; shape.len()];
    for (coord, &dim) in output_vec.iter_mut().rev().zip(shape.iter().rev()) {
        *coord = remaining % dim;
        remaining /= dim;
    }

    // Anything left over means `idx` was past the last element.
    if remaining != 0 {
        return Vec::new();
    }
    output_vec
}