ferrite/tensor/device/cpu/storage/
base.rs1use std::sync::{Arc, RwLock};
2use crate::*;
3use ndarray::{ArrayBase, Dimension};
4use num_traits::cast::AsPrimitive;
5use rand::distributions::{Distribution, Uniform};
6
/// Tensor storage held in host memory: one shared `f32` buffer together with the
/// shape / stride / offset metadata used to interpret it.
///
/// The buffer lives behind an `Arc`, so cloning a `CpuStorage` copies only the
/// metadata; every clone refers to the same elements.
#[derive(Clone)]
pub struct CpuStorage {
    /// Element buffer, shared by clones and guarded by a reader/writer lock.
    data: Arc<RwLock<Vec<f32>>>,
    /// Extent of each dimension.
    shape: Vec<usize>,
    /// Per-dimension step through `data`: coordinate `i` along dimension `d`
    /// contributes `i * stride[d]` to the element's position.
    stride: Vec<usize>,
    /// Starting position associated with this storage; every constructor in this
    /// module sets it to 0 (presumably the index of the first element — confirm).
    offset: usize,
}
14
15impl DeviceStorageStatic for CpuStorage {
16 fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
17 if data.len() != shape.iter().product::<usize>() {
19 let x: usize = shape.iter().product::<usize>();
20 println!("Data Len: {}. Shape iter prod {}", data.len(), x);
21 println!("Data: {:?}", data);
22 panic!("Data does not match shape!");
23 }
24 let stride = CpuStorage::compute_strides(&shape);
25 CpuStorage {
26 data: Arc::new(RwLock::new(data)),
27 shape: shape,
28 stride: stride,
29 offset: 0,
30 }
31 }
32
33 fn new_with_stride(data: Vec<f32>, shape: Vec<usize>, stride: Vec<usize>) -> Self {
34 if data.len() != shape.iter().product::<usize>() {
35 panic!("Data does not match shape!");
36 }
37 CpuStorage {
38 data: Arc::new(RwLock::new(data)),
39 shape: shape,
40 stride: stride,
41 offset: 0,
42 }
43 }
44
45 fn create(data: Arc<RwLock<Vec<f32>>>, shape: Vec<usize>, stride: Vec<usize>) -> Self {
46 CpuStorage {
47 data: data,
48 shape: shape,
49 stride: stride,
50 offset: 0,
51 }
52 }
53
54 fn compute_strides(shape: &Vec<usize>) -> Vec<usize> {
55 let mut stride = vec![1; shape.len()];
56 for i in (0..shape.len() - 1).rev() {
57 stride[i] = stride[i + 1] * shape[i + 1];
58 }
59 stride
60 }
61}
62
63impl DeviceStorageCreation for CpuStorage {
64 fn zeros(shape: Vec<usize>, _device: Option<Device>, _requires_grad: Option<bool>) -> Self {
65 let size = shape.iter().product();
66 let data = vec![0.0; size];
67 CpuStorage::new(data, shape)
68 }
69
70 fn ones(shape: Vec<usize>, _device: Option<Device>, _requires_grad: Option<bool>) -> Self {
71 let size = shape.iter().product();
72 let data = vec![1.0; size];
73 CpuStorage::new(data, shape)
74 }
75
76 fn from_ndarray<S, D, T>(
77 data: &ArrayBase<S, D>,
78 _device: Option<Device>,
79 _requires_grad: Option<bool>,
80 ) -> Self
81 where
82 S: ndarray::Data<Elem = T>,
83 T: AsPrimitive<f32>,
84 D: Dimension,
85 {
86 let shape = data.shape().to_vec();
87 let arr = data.mapv(|x| x.as_());
88 let data = arr.iter().cloned().collect();
89 CpuStorage::new(data, shape)
90 }
91
92 fn uniform(
93 l_bound: f32,
94 r_bound: f32,
95 shape: Vec<usize>,
96 _device: Option<Device>,
97 _requires_grad: Option<bool>,
98 ) -> Self {
99 let uniform = Uniform::from(l_bound..r_bound); let mut rng = rand::thread_rng(); let data = (0..shape.iter().product())
102 .map(|_| uniform.sample(&mut rng)) .collect();
104 CpuStorage::new(data, shape)
105 }
106}
107
108impl DeviceStorage for CpuStorage {
109 fn view(&self, new_shape: Vec<usize>) -> Self {
110 let total_elements: usize = new_shape.iter().product();
112 if total_elements != self.shape.iter().product::<usize>() {
113 panic!("New shape must have the same number of elements");
114 }
115 let stride = CpuStorage::compute_strides(&new_shape);
116 CpuStorage {
117 data: Arc::clone(&self.data),
118 shape: new_shape,
119 stride: stride,
120 offset: self.offset,
121 }
122 }
123
124 fn data(&self) -> Arc<RwLock<Vec<f32>>> {
125 Arc::clone(&self.data)
126 }
127
128 fn data_mut(&self) -> std::sync::RwLockWriteGuard<Vec<f32>> {
129 self.data.write().unwrap()
130 }
131
132 fn set_data(&mut self, data: Vec<f32>) {
133 self.data = Arc::new(RwLock::new(data));
134 }
135
136 fn shape(&self) -> &Vec<usize> {
137 &self.shape
138 }
139
140 fn set_shape(&mut self, shape: Vec<usize>) {
141 self.shape = shape;
142 }
143
144 fn stride(&self) -> &Vec<usize> {
145 &self.stride
146 }
147
148 fn set_stride(&mut self, stride: Vec<usize>) {
149 self.stride = stride;
150 }
151
152 fn offset(&self) -> usize {
153 self.offset
154 }
155
156 fn get(&self, indices: &[usize]) -> f32 {
157 if indices.len() != self.shape.len() {
159 panic!("Tensor index does not match shape!");
160 }
161 let mut flat_index = 0;
163 for (i, &idx) in indices.iter().enumerate() {
164 if idx >= self.shape[i] {
165 panic!("Tensor index out of bounds!");
166 }
167 flat_index += idx * self.stride[i];
168 }
169 let data = self.data.read().unwrap();
171 data[flat_index]
172 }
173
174 fn set(&mut self, indices: &[usize], value: f32) {
175 if indices.len() != self.shape.len() {
176 panic!("Tensor index does not match shape!");
177 }
178 let mut flat_index = 0;
179 for (i, &idx) in indices.iter().enumerate() {
180 if idx >= self.shape[i] {
181 panic!("Tensor index out of bounds!");
182 }
183 flat_index += idx * self.stride[i];
184 }
185 let mut data = self.data.write().unwrap();
187 data[flat_index] = value;
188 }
189
190 fn make_contiguous(&self) -> (Vec<f32>, i32) {
191 if self.is_contiguous() {
192 return (self.data.read().unwrap().clone(), self.shape[1] as i32);
193 }
194 let mut contiguous = vec![0.0; self.shape.iter().product()];
195 for i in 0..self.shape[0] {
196 for j in 0..self.shape[1] {
197 contiguous[i * self.shape[1] + j] = self.get(&[i, j]);
198 }
199 }
200 (contiguous, self.shape[1] as i32)
201 }
202
203 fn is_contiguous(&self) -> bool {
204 let mut expected_stride = 1;
205 for i in (0..self.shape.len()).rev() {
206 if self.stride[i] != expected_stride {
207 return false;
208 }
209 expected_stride *= self.shape[i];
210 }
211 true
212 }
213}