flowtorch_core/
device.rs

1use crate::{storage::Storage, DType};
2
/// A compute device on which tensor storage can be allocated.
///
/// Only a CPU backend exists at present.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Device {
    /// Host (CPU) memory.
    Cpu,
}
7
8impl Device {
9    pub fn zeros(&self, shape: &[usize], dtype: DType) -> Storage {
10        match self {
11            Device::Cpu => {
12                let elem_count: usize = shape.iter().product();
13                let buffer: Vec<u8> = match dtype {
14                    DType::F32 => {
15                        let data = vec![0f32; elem_count];
16                        data.iter().flat_map(|&x| x.to_le_bytes()).collect()
17                    }
18                    DType::F64 => {
19                        let data = vec![0f64; elem_count];
20                        data.iter().flat_map(|&x| x.to_le_bytes()).collect()
21                    }
22                };
23                //let buffer = vec![0; elem_count * dtype.size_in_bytes()];
24                Storage::Cpu { dtype, buffer }
25            }
26        }
27    }
28}