Skip to main content

radiate_utils/buff/
value.rs

1use crate::{Shape, Strides};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::sync::Arc;
6
7#[derive(PartialEq, Eq, Hash, Clone)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9pub enum Value<T> {
10    Scalar(T),
11    Array {
12        values: Arc<[T]>,
13        shape: Shape,
14        strides: Strides,
15    },
16}
17
18impl<T> Value<T> {
19    pub fn shape(&self) -> Option<&Shape> {
20        match self {
21            Value::Array { shape, .. } => Some(shape),
22            _ => None,
23        }
24    }
25
26    pub fn strides(&self) -> Option<&[usize]> {
27        match self {
28            Value::Scalar(_) => None,
29            Value::Array { strides, .. } => Some(&strides.as_slice()),
30        }
31    }
32
33    pub fn as_scalar(&self) -> Option<&T> {
34        match self {
35            Value::Scalar(value) => Some(value),
36            _ => None,
37        }
38    }
39
40    pub fn as_array(&self) -> Option<&[T]> {
41        match self {
42            Value::Array { values, .. } => Some(values),
43            _ => None,
44        }
45    }
46}
47
48impl<S, T, F> From<(S, F)> for Value<T>
49where
50    S: Into<Shape>,
51    F: FnMut(usize) -> T,
52{
53    fn from(value: (S, F)) -> Self {
54        let (shape, mut f) = value;
55        let dims = shape.into();
56
57        let mut strides = vec![1; dims.rank()];
58        for i in (0..dims.rank() - 1).rev() {
59            strides[i] = strides[i + 1] * dims.dim_at(i + 1);
60        }
61
62        let size = dims.size();
63        let mut values = Vec::with_capacity(size);
64        for index in 0..size {
65            values.push(f(index));
66        }
67
68        Value::Array {
69            values: Arc::from(values),
70            shape: dims.clone(),
71            strides: Strides::from(strides),
72        }
73    }
74}
75
76impl<T: Debug> Debug for Value<T> {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        match self {
79            Value::Scalar(value) => write!(f, "Scalar({:?})", value),
80            Value::Array { shape, strides, .. } => {
81                write!(f, "Arr(shape={:?}, strides={:?})", shape, strides)
82            }
83        }
84    }
85}