ferrite/tensor/storage/
traits.rs1use std::sync::Arc;
2use std::{rc::Rc, sync::RwLock};
3use std::cell::RefCell;
4use ndarray::{ArrayBase, Dimension};
5use num_traits::cast::AsPrimitive;
6
7use crate::{CpuStorage};
8
/// Backend on which a storage's data lives.
///
/// Derives `Copy`, `Eq` and `Hash`, so a `Device` is passed by value and can be used as a
/// map/set key. `Debug` is derived so a `Device` can be shown with `{:?}`, in `assert_eq!`
/// failure messages, and inside the derived `Debug` output of types that contain it.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum Device {
    /// CPU backend.
    Cpu,
    /// CUDA backend.
    Cuda,
    /// MPS backend.
    Mps,
}
16
17pub trait DeviceStorageStatic : DeviceStorage {
18 fn new(data: Vec<f32>, shape: Vec<usize>) -> Self;
19
20 fn new_with_stride(data: Vec<f32>, shape: Vec<usize>, stride: Vec<usize>) -> Self;
21
22 fn create(data: Arc<RwLock<Vec<f32>>>, shape: Vec<usize>, stride: Vec<usize>) -> Self;
23
24 fn compute_strides(shape: &Vec<usize>) -> Vec<usize>;
25}
26
27pub trait DeviceStorageCreation : DeviceStorage {
28 fn zeros(shape: Vec<usize>, device: Option<Device>, requires_grad: Option<bool>) -> Self;
29 fn ones(shape: Vec<usize>, device: Option<Device>, requires_grad: Option<bool>) -> Self;
30 fn from_ndarray<S, D, T>(data: &ArrayBase<S, D>, device: Option<Device>, requires_grad: Option<bool>) -> Self
31 where
32 S: ndarray::Data<Elem = T>,
33 T: AsPrimitive<f32>,
34 D: Dimension;
35
36 fn uniform(l_bound: f32, r_bound: f32, shape: Vec<usize>, device: Option<Device>, requires_grad: Option<bool>) -> Self;
37}
38
39
/// Object-safe interface to a backend's flat `f32` buffer and its shape/stride layout.
///
/// The buffer is shared behind `Arc<RwLock<_>>`, so several storages may alias the same data.
/// Apart from `view` (which is gated on `Self: Sized`), every method takes a receiver, so the
/// trait can be used as `dyn DeviceStorage`.
pub trait DeviceStorage {
    // ---- layout -------------------------------------------------------------------------

    /// Returns a storage with `new_shape`.
    ///
    /// NOTE(review): presumably shares this storage's buffer rather than copying — confirm.
    fn view(&self, new_shape: Vec<usize>) -> Self
    where
        Self: Sized;

    /// Current shape.
    fn shape(&self) -> &Vec<usize>;

    /// Overwrites the shape.
    ///
    /// NOTE(review): callers presumably must keep `stride` consistent themselves.
    fn set_shape(&mut self, shape: Vec<usize>);

    /// Current stride vector.
    fn stride(&self) -> &Vec<usize>;

    /// Overwrites the stride vector.
    fn set_stride(&mut self, stride: Vec<usize>);

    /// Element offset into the buffer.
    ///
    /// NOTE(review): presumably the index of the first logical element (nonzero for views).
    fn offset(&self) -> usize;

    /// Whether the elements are laid out contiguously. Exact criterion is up to the implementor.
    fn is_contiguous(&self) -> bool;

    // ---- buffer access -------------------------------------------------------------------

    /// Returns a handle to the shared element buffer.
    ///
    /// NOTE(review): presumably an `Arc` clone, so writes through it are visible here.
    fn data(&self) -> Arc<RwLock<Vec<f32>>>;

    /// Write-locks the shared buffer through `&self`.
    ///
    /// The returned guard holds the lock until it is dropped, so keep it short-lived.
    fn data_mut(&self) -> std::sync::RwLockWriteGuard<'_, Vec<f32>>;

    /// Replaces the element buffer with `data`.
    ///
    /// NOTE(review): whether other storages sharing the old buffer observe this depends on
    /// the implementor.
    fn set_data(&mut self, data: Vec<f32>);

    /// Returns the elements in contiguous order, plus an `i32` whose meaning is defined by the
    /// implementor.
    ///
    /// NOTE(review): document the `i32` component.
    fn make_contiguous(&self) -> (Vec<f32>, i32);

    // ---- element access ------------------------------------------------------------------

    /// Reads the element at `indices` (presumably one index per dimension).
    fn get(&self, indices: &[usize]) -> f32;

    /// Writes `value` at `indices` (presumably one index per dimension).
    fn set(&mut self, indices: &[usize], value: f32);
}
67
68
69#[derive(Clone)]
70pub enum Storage {
71 Cpu(CpuStorage),
72}