ferrite/tensor/device/cpu/kernels/transform.rs

1use crate::*;
2
3impl TransformOps for CpuStorage {
4  fn apply_assign<F>(&mut self, op: F)
5  where
6    F: Fn(f32) -> f32,
7  {
8    let data = self.data().read().unwrap().iter()
9      .map(|a| op(*a))
10      .collect();
11
12    self.set_data(data);
13  }
14
15  fn elementwise_op_assign<F>(&mut self, other: &Self, op: F)
16  where
17    F: Fn(f32, f32) -> f32,
18  {
19    let total_elements = self.shape().iter().product();
20    let mut result = vec![0.0; total_elements];
21    
22    // Get data once to avoid multiple borrows
23    let self_binding = self.data();
24    let self_data = self_binding.read().unwrap();
25    let other_binding = other.data();
26    let other_data = other_binding.read().unwrap();
27    
28    // Pre-calculate dimensions for faster access
29    let rank = self.shape().len();
30    let shape = self.shape();
31    let self_strides = self.stride();
32    let other_strides = other.stride();
33    
34    // Use chunk size optimization for contiguous dimensions
35    let mut chunk_size = 1;
36    let mut contiguous_dims = 0;
37    for dim in (0..rank).rev() {
38      if self_strides[dim] == chunk_size && other_strides[dim] == chunk_size {
39        chunk_size *= shape[dim];
40        contiguous_dims += 1;
41      } else {
42        break;
43      }
44    }
45    
46    let outer_dims = rank - contiguous_dims;
47    let mut indices = vec![0; outer_dims];
48    
49    // Process chunks
50    let chunks = total_elements / chunk_size;
51    for chunk_idx in 0..chunks {
52      // Calculate base indices for the chunk
53      let mut self_base_idx = 0;
54      let mut other_base_idx = 0;
55      
56      for (dim, &idx) in indices.iter().enumerate() {
57        self_base_idx += idx * self_strides[dim];
58        other_base_idx += idx * other_strides[dim];
59      }
60      
61      // Process the entire chunk
62      let result_start = chunk_idx * chunk_size;
63      for i in 0..chunk_size {
64        let self_val = self_data[self_base_idx + i];
65        let other_val = other_data[other_base_idx + i];
66        result[result_start + i] = op(self_val, other_val);
67      }
68      
69      // Update indices for outer dimensions
70      for dim in (0..outer_dims).rev() {
71        indices[dim] += 1;
72        if indices[dim] < shape[dim] {
73          break;
74        }
75        indices[dim] = 0;
76      }
77    }
78
79    self.set_data(result);
80  }
81
82  fn reshape(&mut self, new_shape: Vec<usize>) {
83    self.set_shape(new_shape);
84  }
85
86  fn scalar_op_assign<F>(&mut self, scalar: f32, op: F)
87  where
88    F: Fn(f32, f32) -> f32,
89  {
90    let data = self.data().read().unwrap().iter()
91      .map(|a| op(*a, scalar))
92      .collect();
93
94    self.set_data(data);
95  }
96
97  fn permute(&mut self, dims: &[usize]) {
98    let self_shape = self.shape();
99    let shape = dims.iter().map(|&i| self_shape[i]).collect();    
100
101    let self_stride = self.stride();
102    let stride = dims.iter().map(|&i| self_stride[i]).collect();   
103
104    self.set_shape(shape);
105    self.set_stride(stride);
106  }
107
108  fn flatten(&mut self) {
109    let shape: Vec<usize> = vec![self.shape().iter().product()];
110    let stride = vec![1];
111
112    self.set_shape(shape);
113    self.set_stride(stride);
114  }
115
116  fn squeeze(&mut self) {
117    // Remove all 1 dimension from the shape
118    let shape: Vec<usize> = self.shape().to_owned().iter().filter(|&&x| x != 1).cloned().collect();
119    let stride = Self::compute_strides(&shape);
120
121    self.set_shape(shape);
122    self.set_stride(stride);
123  } 
124
125  fn unsqueeze(&mut self, dim: usize) {
126    let mut shape: Vec<usize> = self.shape().to_owned();
127    shape.insert(dim, 1);
128    let stride = Self::compute_strides(&shape);
129
130    self.set_shape(shape);
131    self.set_stride(stride);
132  }
133
134  fn apply<F>(&self, op: F) -> Self
135  where
136    F: Fn(f32) -> f32,
137  {
138    let data = self.data().read().unwrap().iter()
139      .map(|a| op(*a))
140      .collect();
141
142    Self::new(data, self.shape().clone())
143  }
144
145  fn elementwise_op<F>(&self, other: &Self, op: F) -> Self
146  where
147    F: Fn(f32, f32) -> f32,
148  {
149    let total_elements = self.shape().iter().product();
150    let mut result = vec![0.0; total_elements];
151    
152    // Get data once to avoid multiple borrows
153    let self_binding = self.data();
154    let self_data = self_binding.read().unwrap();
155    let other_binding = other.data();
156    let other_data = other_binding.read().unwrap();
157    
158    // Pre-calculate dimensions for faster access
159    let rank = self.shape().len();
160    let shape = self.shape();
161    let self_strides = self.stride();
162    let other_strides = other.stride();
163    
164    // Use chunk size optimization for contiguous dimensions
165    let mut chunk_size = 1;
166    let mut contiguous_dims = 0;
167    for dim in (0..rank).rev() {
168      if self_strides[dim] == chunk_size && other_strides[dim] == chunk_size {
169        chunk_size *= shape[dim];
170        contiguous_dims += 1;
171      } else {
172        break;
173      }
174    }
175    
176    let outer_dims = rank - contiguous_dims;
177    let mut indices = vec![0; outer_dims];
178    
179    // Process chunks
180    let chunks = total_elements / chunk_size;
181    for chunk_idx in 0..chunks {
182      // Calculate base indices for the chunk
183      let mut self_base_idx = 0;
184      let mut other_base_idx = 0;
185      
186      for (dim, &idx) in indices.iter().enumerate() {
187        self_base_idx += idx * self_strides[dim];
188        other_base_idx += idx * other_strides[dim];
189      }
190      
191      // Process the entire chunk
192      let result_start = chunk_idx * chunk_size;
193      for i in 0..chunk_size {
194        let self_val = self_data[self_base_idx + i];
195        let other_val = other_data[other_base_idx + i];
196        result[result_start + i] = op(self_val, other_val);
197      }
198      
199      // Update indices for outer dimensions
200      for dim in (0..outer_dims).rev() {
201        indices[dim] += 1;
202        if indices[dim] < shape[dim] {
203          break;
204        }
205        indices[dim] = 0;
206      }
207    }
208
209    Self::new(result, self.shape().clone())
210  }
211
212  fn scalar_op<F>(&self, scalar: f32, op: F) -> Self
213  where
214    F: Fn(f32, f32) -> f32,
215  {
216    let data = self.data().read().unwrap().iter()
217      .map(|a| op(*a, scalar))
218      .collect();
219
220    Self::new(data, self.shape().clone())
221  }
222
223  fn sum_dim(&self, dims: &[bool]) -> Self {
224    // Handle scalar case
225    if self.shape().len() == 1 && self.shape()[0] == 1 {
226        return self.clone();
227    }
228
229    // Calculate new shape excluding summed dimensions
230    let mut new_shape: Vec<usize> = self.shape().iter()
231        .zip(dims.iter().chain(std::iter::repeat(&false)))
232        .filter_map(|(&dim, &should_sum)| if !should_sum { Some(dim) } else { None })
233        .collect();
234
235    // If all dimensions are summed, return scalar
236    if new_shape.is_empty() {
237        let sum: f32 = self.data().read().unwrap().iter().sum();
238        return Self::new(vec![sum], vec![1]);
239    }
240
241    // Ensure at least one dimension
242    if new_shape.is_empty() {
243        new_shape.push(1);
244    }
245
246    let mut result = vec![0.0; new_shape.iter().product()];
247    
248    // Sum values maintaining non-summed dimensions
249    let mut sum = 0.0;
250    let binding = self.data();
251    let data = binding.read().unwrap();
252    for i in 0..data.len() {
253        sum += data[i];
254    }
255    result[0] = sum;
256
257    Self::new(result, new_shape)
258  }
259
260  fn transpose(&self) -> Self {
261    // Transpose by swapping dimensions & strides
262    if self.shape().len() != 2 { panic!("Must be 2-D Tensor (Matrix)"); }
263
264    let mut shape = self.shape().to_owned();
265    shape.reverse();
266
267    let mut stride = self.stride().to_owned();
268    stride.reverse();
269
270    Self::create(self.data(), shape, stride)
271  }
272
273  fn broadcast(&self, new_shape: &[usize]) -> Self {
274    // Verify broadcast compatibility and get output shape
275    let broadcast_shape = self.compute_broadcast_shape(new_shape);
276    
277    // Calculate new strides for broadcasting
278    let broadcast_strides = self.compute_broadcast_strides(&broadcast_shape);
279
280    Self::create(self.data(), broadcast_shape, broadcast_strides)
281  }
282
283  /// Compute broadcast shape between two shapes
284  fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize> {
285    let self_rank = self.shape().len();
286    let target_rank = target_shape.len();
287    let max_rank = std::cmp::max(self_rank, target_rank);
288    
289    // Pad shapes with 1s to match maximum rank
290    let self_padded = self.pad_shape(max_rank);
291    let mut result_shape = Vec::with_capacity(max_rank);
292
293    // Compare dimensions from right to left
294    for i in 0..max_rank {
295      let self_dim = self_padded[i];
296      let target_dim = if i >= max_rank - target_rank {
297          target_shape[i - (max_rank - target_rank)]
298        } else {
299          1
300        };
301
302      if self_dim == target_dim {
303        result_shape.push(self_dim);
304      } else if self_dim == 1 {
305        result_shape.push(target_dim);
306      } else if target_dim == 1 {
307        result_shape.push(self_dim);
308      } else {
309        panic!(
310          "Incompatible broadcast dimensions: {} and {}",
311          self_dim, target_dim
312        )
313      }
314    }
315
316    result_shape
317  }
318
319  /// Compute broadcast strides
320  fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize> {
321    let self_rank = self.shape().len();
322    let broadcast_rank = broadcast_shape.len();
323    let rank_diff = broadcast_rank - self_rank;
324    
325    let mut new_strides = vec![0; broadcast_rank];
326    
327    // Fill in strides for dimensions that match or are broadcasted
328    for i in 0..self_rank {
329      let broadcast_idx = i + rank_diff;
330      if broadcast_shape[broadcast_idx] == self.shape()[i] {
331        new_strides[broadcast_idx] = self.stride()[i];
332      } else if self.shape()[i] == 1 {
333        new_strides[broadcast_idx] = 0;  // Stride of 0 for broadcasted dimensions
334      } else {
335        panic!("Invalid broadcast shape")
336      }
337    }
338
339    new_strides
340  }
341
342  /// Pad shape with ones on the left
343  fn pad_shape(&self, target_rank: usize) -> Vec<usize> {
344    let mut padded = vec![1; target_rank];
345    let rank_diff = target_rank - self.shape().len();
346    padded[rank_diff..].copy_from_slice(self.shape());
347    padded
348  }
349
350  fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self) {
351    let broadcast_shape = a.compute_broadcast_shape(b.shape());
352    let broadcast_a = a.broadcast(&broadcast_shape);
353    let broadcast_b = b.broadcast(&broadcast_shape);
354    (broadcast_a, broadcast_b)
355  }
356}