flashlight_tensor/cpu/subtypes/vector.rs
1use crate::tensor::*;
2
3impl<T: Default + Clone> Tensor<T>{
4 /// Get vector from Tensor on position
5 ///
6 /// # Example
7 /// ```
8 /// use flashlight_tensor::prelude::*;
9 ///
10 /// let data: Vec<f32> = vec!{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
11 /// let sizes: Vec<u32> = vec!{2, 3};
12 ///
13 /// let tensor = Tensor::from_data(&data, &sizes).unwrap();
14 ///
15 /// let vector = tensor.vector(&[0]).unwrap();
16 ///
17 /// let expected_data: Vec<f32> = vec!{1.0, 2.0, 3.0};
18 /// let expected_sizes: Vec<u32> = vec!{3};
19 ///
20 /// assert_eq!(vector.get_data(), &expected_data);
21 /// assert_eq!(vector.get_shape(), &expected_sizes);
22 /// ```
23 pub fn vector(&self, pos: &[u32]) -> Option<Tensor<T>>{
24 let self_dimensions = self.get_shape().len();
25 let selector_dimensions = pos.len();
26 if self_dimensions - selector_dimensions != 1{
27 return None;
28 }
29
30 for i in 0..pos.len(){
31 if pos[i] >= self.get_shape()[i]{
32 return None;
33 }
34 }
35
36 let mut data_begin: u32 = 0;
37
38 let mut stride = self.get_shape()[0];
39
40 for i in 0..pos.len() {
41 data_begin += pos[pos.len() - 1 - i] * stride;
42 stride *= self.get_shape()[1+i];
43 }
44
45 let data_end: u32 = data_begin + self.get_shape().get(self.get_shape().len()-1).unwrap();
46
47 let data = self.get_data()[data_begin as usize..data_end as usize].to_vec();
48 let sizes = self.get_shape()[self.get_shape().len()-1..self.get_shape().len()].to_vec();
49
50 Tensor::from_data(&data, &sizes)
51 }
52}
53
54impl Tensor<f32>{
55 /// Get dot product from tensors if tensors have one dimenstion
56 /// and have same size
57 ///
58 /// # Example
59 /// ```
60 /// use flashlight_tensor::prelude::*;
61 ///
62 /// let data1: Vec<f32> = vec!{1.0, 2.0, 3.0};
63 /// let data2: Vec<f32> = vec!{3.0, 2.0, 1.0};
64 ///
65 /// let expected: f32 = 10.0; // 3.0 + 4.0 + 3.0
66 ///
67 /// let tensor1 = Tensor::from_data(&data1, &[3]).unwrap();
68 /// let tensor2 = Tensor::from_data(&data2, &[3]).unwrap();
69 ///
70 /// let result = tensor1.dot_product(&tensor2).unwrap();
71 ///
72 /// assert_eq!(result, expected);
73 /// ```
74 pub fn dot_product(&self, tens2: &Tensor<f32>) -> Option<f32>{
75 if self.get_shape().len() != 1{
76 return None;
77 }
78 if self.get_shape() != tens2.get_shape(){
79 return None;
80 }
81
82 let mut dot: f32 = 0.0;
83 for i in 0..self.get_shape()[0] as u32{
84 dot += self.value(&[i]).unwrap() * tens2.value(&[i]).unwrap();
85 }
86
87 Some(dot)
88 }
89}