// ferrite/tensor/storage/traits.rs

1use std::sync::Arc;
2use std::{rc::Rc, sync::RwLock};
3use std::cell::RefCell;
4use ndarray::{ArrayBase, Dimension};
5use num_traits::cast::AsPrimitive;
6
7use crate::{CpuStorage};
// Device types

/// Compute backend that a storage is associated with.
///
/// `Debug` is derived alongside the comparison/hash traits so a `Device` can be
/// formatted with `{:?}`, which also lets `assert_eq!` report mismatches.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum Device {
  /// CPU backend.
  Cpu,
  /// CUDA backend.
  Cuda,
  /// MPS backend (presumably Apple Metal Performance Shaders — confirm).
  Mps,
}
17pub trait DeviceStorageStatic : DeviceStorage {
18  fn new(data: Vec<f32>, shape: Vec<usize>) -> Self;
19
20  fn new_with_stride(data: Vec<f32>, shape: Vec<usize>, stride: Vec<usize>) -> Self;
21
22  fn create(data: Arc<RwLock<Vec<f32>>>, shape: Vec<usize>, stride: Vec<usize>) -> Self;
23
24  fn compute_strides(shape: &Vec<usize>) -> Vec<usize>;
25}
26
27pub trait DeviceStorageCreation : DeviceStorage {
28  fn zeros(shape: Vec<usize>, device: Option<Device>, requires_grad: Option<bool>) -> Self;
29  fn ones(shape: Vec<usize>, device: Option<Device>, requires_grad: Option<bool>) -> Self;
30  fn from_ndarray<S, D, T>(data: &ArrayBase<S, D>, device: Option<Device>, requires_grad: Option<bool>) -> Self
31  where 
32    S: ndarray::Data<Elem = T>,
33    T: AsPrimitive<f32>,
34    D: Dimension;
35
36  fn uniform(l_bound: f32, r_bound: f32, shape: Vec<usize>, device: Option<Device>, requires_grad: Option<bool>) -> Self;
37}
38
39
/// Instance-level interface shared by every storage backend.
///
/// It covers four concerns:
/// - the shared, lock-protected `f32` buffer;
/// - its shape/stride/offset layout metadata;
/// - element access by multi-dimensional index;
/// - reshaped views.
///
/// Apart from `view` (which is restricted to `Self: Sized`), every method takes
/// `self` by reference and never mentions `Self`. The trait can therefore also
/// be used as `dyn DeviceStorage`.
pub trait DeviceStorage {
  // ---- layout metadata ----

  /// Size of each dimension.
  fn shape(&self) -> &Vec<usize>;

  /// Replaces the shape metadata.
  fn set_shape(&mut self, shape: Vec<usize>);

  /// Step (presumably in elements) between consecutive indices of each
  /// dimension.
  fn stride(&self) -> &Vec<usize>;

  /// Replaces the stride metadata.
  fn set_stride(&mut self, stride: Vec<usize>);

  /// Position of this storage's first element within the underlying buffer
  /// (presumably in elements — confirm in implementors).
  fn offset(&self) -> usize;

  /// Whether the current shape/stride describe a contiguous layout, under
  /// the implementor's stride convention.
  fn is_contiguous(&self) -> bool;

  /// Returns the data as a contiguous `Vec<f32>` together with an `i32`.
  ///
  /// NOTE(review): what the `i32` denotes (an offset? a status code?) is not
  /// evident from this trait — confirm against the implementors.
  fn make_contiguous(&self) -> (Vec<f32>, i32);

  // ---- underlying buffer ----

  /// Returns an owned `Arc` handle to the shared, lock-protected buffer.
  fn data(&self) -> Arc<RwLock<Vec<f32>>>;

  /// Acquires write access to the buffer through a shared `&self` receiver.
  ///
  /// The returned guard borrows from `self` and holds the write lock until it
  /// is dropped. Whether other accessors lock the same buffer is up to the
  /// implementor. If they do, calling them while the guard is alive may
  /// deadlock or panic, as described in the `std::sync::RwLock` docs.
  fn data_mut(&self) -> std::sync::RwLockWriteGuard<'_, Vec<f32>>;

  /// Replaces the storage's data with `data`.
  fn set_data(&mut self, data: Vec<f32>);

  // ---- element access ----

  /// Reads the element at the multi-dimensional position `indices`.
  fn get(&self, indices: &[usize]) -> f32;

  /// Writes `value` at the multi-dimensional position `indices`.
  fn set(&mut self, indices: &[usize], value: f32);

  // ---- views ----

  /// Returns a storage presented with `new_shape`. Presumably it shares this
  /// storage's buffer — confirm in implementors.
  ///
  /// The `Self: Sized` bound excludes this method from `dyn DeviceStorage`
  /// dispatch, because it returns `Self` by value.
  fn view(&self, new_shape: Vec<usize>) -> Self
  where
    Self: Sized;
}
69#[derive(Clone)]
70pub enum Storage {
71  Cpu(CpuStorage),
72}