//! CPU tensor storage backend (ferrite/tensor/device/cpu/storage/base.rs).

use std::sync::{Arc, RwLock};

use crate::*;
use ndarray::{ArrayBase, Dimension};
use num_traits::cast::AsPrimitive;
use rand::distributions::{Distribution, Uniform};
6
/// Dense `f32` tensor storage living in host memory.
///
/// The flat buffer is shared behind `Arc<RwLock<_>>`, so views produced by
/// `view` alias the same data. `shape`, `stride`, and `offset` describe a
/// strided layout into that buffer.
#[derive(Clone)]
pub struct CpuStorage {
    /// Shared, lock-protected flat element buffer.
    data: Arc<RwLock<Vec<f32>>>,
    /// Extent of each dimension.
    shape: Vec<usize>,
    /// Element step per dimension (canonical row-major when contiguous).
    stride: Vec<usize>,
    /// Index of the first element inside `data`.
    offset: usize,
}
14
15impl DeviceStorageStatic for CpuStorage {
16    fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
17        // Check that the data length matches the product of shape dimensions.
18        if data.len() != shape.iter().product::<usize>() {
19            let x: usize = shape.iter().product::<usize>();
20            println!("Data Len: {}. Shape iter prod {}", data.len(), x);
21            println!("Data: {:?}", data);
22            panic!("Data does not match shape!");
23        }
24        let stride = CpuStorage::compute_strides(&shape);
25        CpuStorage {
26            data: Arc::new(RwLock::new(data)),
27            shape: shape,
28            stride: stride,
29            offset: 0,
30        }
31    }
32
33    fn new_with_stride(data: Vec<f32>, shape: Vec<usize>, stride: Vec<usize>) -> Self {
34        if data.len() != shape.iter().product::<usize>() {
35            panic!("Data does not match shape!");
36        }
37        CpuStorage {
38            data: Arc::new(RwLock::new(data)),
39            shape: shape,
40            stride: stride,
41            offset: 0,
42        }
43    }
44
45    fn create(data: Arc<RwLock<Vec<f32>>>, shape: Vec<usize>, stride: Vec<usize>) -> Self {
46        CpuStorage {
47            data: data,
48            shape: shape,
49            stride: stride,
50            offset: 0,
51        }
52    }
53
54    fn compute_strides(shape: &Vec<usize>) -> Vec<usize> {
55        let mut stride = vec![1; shape.len()];
56        for i in (0..shape.len() - 1).rev() {
57            stride[i] = stride[i + 1] * shape[i + 1];
58        }
59        stride
60    }
61}
62
63impl DeviceStorageCreation for CpuStorage {
64    fn zeros(shape: Vec<usize>, _device: Option<Device>, _requires_grad: Option<bool>) -> Self {
65        let size = shape.iter().product();
66        let data = vec![0.0; size];
67        CpuStorage::new(data, shape)
68    }
69
70    fn ones(shape: Vec<usize>, _device: Option<Device>, _requires_grad: Option<bool>) -> Self {
71        let size = shape.iter().product();
72        let data = vec![1.0; size];
73        CpuStorage::new(data, shape)
74    }
75
76    fn from_ndarray<S, D, T>(
77        data: &ArrayBase<S, D>,
78        _device: Option<Device>,
79        _requires_grad: Option<bool>,
80    ) -> Self
81    where
82        S: ndarray::Data<Elem = T>,
83        T: AsPrimitive<f32>,
84        D: Dimension,
85    {
86        let shape = data.shape().to_vec();
87        let arr = data.mapv(|x| x.as_());
88        let data = arr.iter().cloned().collect();
89        CpuStorage::new(data, shape)
90    }
91
92    fn uniform(
93        l_bound: f32,
94        r_bound: f32,
95        shape: Vec<usize>,
96        _device: Option<Device>,
97        _requires_grad: Option<bool>,
98    ) -> Self {
99        let uniform = Uniform::from(l_bound..r_bound); // Create a uniform distribution
100        let mut rng = rand::thread_rng(); // Random number generator
101        let data = (0..shape.iter().product())
102            .map(|_| uniform.sample(&mut rng)) // Sample from the uniform distribution
103            .collect();
104        CpuStorage::new(data, shape)
105    }
106}
107
108impl DeviceStorage for CpuStorage {
109    fn view(&self, new_shape: Vec<usize>) -> Self {
110        // Check if the new shape is compatible.
111        let total_elements: usize = new_shape.iter().product();
112        if total_elements != self.shape.iter().product::<usize>() {
113            panic!("New shape must have the same number of elements");
114        }
115        let stride = CpuStorage::compute_strides(&new_shape);
116        CpuStorage {
117            data: Arc::clone(&self.data),
118            shape: new_shape,
119            stride: stride,
120            offset: self.offset,
121        }
122    }
123
124    fn data(&self) -> Arc<RwLock<Vec<f32>>> {
125        Arc::clone(&self.data)
126    }
127
128    fn data_mut(&self) -> std::sync::RwLockWriteGuard<Vec<f32>> {
129        self.data.write().unwrap()
130    }
131
132    fn set_data(&mut self, data: Vec<f32>) {
133        self.data = Arc::new(RwLock::new(data));
134    }
135
136    fn shape(&self) -> &Vec<usize> {
137        &self.shape
138    }
139
140    fn set_shape(&mut self, shape: Vec<usize>) {
141        self.shape = shape;
142    }
143
144    fn stride(&self) -> &Vec<usize> {
145        &self.stride
146    }
147
148    fn set_stride(&mut self, stride: Vec<usize>) {
149        self.stride = stride;
150    }
151
152    fn offset(&self) -> usize {
153        self.offset
154    }
155
156    fn get(&self, indices: &[usize]) -> f32 {
157        // Ensure the number of indices matches the tensor's dimensions.
158        if indices.len() != self.shape.len() {
159            panic!("Tensor index does not match shape!");
160        }
161        // Compute the flat index.
162        let mut flat_index = 0;
163        for (i, &idx) in indices.iter().enumerate() {
164            if idx >= self.shape[i] {
165                panic!("Tensor index out of bounds!");
166            }
167            flat_index += idx * self.stride[i];
168        }
169        // Use a read lock for safe concurrent access.
170        let data = self.data.read().unwrap();
171        data[flat_index]
172    }
173
174    fn set(&mut self, indices: &[usize], value: f32) {
175        if indices.len() != self.shape.len() {
176            panic!("Tensor index does not match shape!");
177        }
178        let mut flat_index = 0;
179        for (i, &idx) in indices.iter().enumerate() {
180            if idx >= self.shape[i] {
181                panic!("Tensor index out of bounds!");
182            }
183            flat_index += idx * self.stride[i];
184        }
185        // Acquire a write lock for mutation.
186        let mut data = self.data.write().unwrap();
187        data[flat_index] = value;
188    }
189
190    fn make_contiguous(&self) -> (Vec<f32>, i32) {
191        if self.is_contiguous() {
192            return (self.data.read().unwrap().clone(), self.shape[1] as i32);
193        }
194        let mut contiguous = vec![0.0; self.shape.iter().product()];
195        for i in 0..self.shape[0] {
196            for j in 0..self.shape[1] {
197                contiguous[i * self.shape[1] + j] = self.get(&[i, j]);
198            }
199        }
200        (contiguous, self.shape[1] as i32)
201    }
202
203    fn is_contiguous(&self) -> bool {
204        let mut expected_stride = 1;
205        for i in (0..self.shape.len()).rev() {
206            if self.stride[i] != expected_stride {
207                return false;
208            }
209            expected_stride *= self.shape[i];
210        }
211        true
212    }
213}