radiate_utils/buff/
value.rs1use crate::{Shape, Strides};
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::sync::Arc;
6
7#[derive(PartialEq, Eq, Hash, Clone)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9pub enum Value<T> {
10 Scalar(T),
11 Array {
12 values: Arc<[T]>,
13 shape: Shape,
14 strides: Strides,
15 },
16}
17
18impl<T> Value<T> {
19 pub fn shape(&self) -> Option<&Shape> {
20 match self {
21 Value::Array { shape, .. } => Some(shape),
22 _ => None,
23 }
24 }
25
26 pub fn strides(&self) -> Option<&[usize]> {
27 match self {
28 Value::Scalar(_) => None,
29 Value::Array { strides, .. } => Some(&strides.as_slice()),
30 }
31 }
32
33 pub fn as_scalar(&self) -> Option<&T> {
34 match self {
35 Value::Scalar(value) => Some(value),
36 _ => None,
37 }
38 }
39
40 pub fn as_array(&self) -> Option<&[T]> {
41 match self {
42 Value::Array { values, .. } => Some(values),
43 _ => None,
44 }
45 }
46}
47
48impl<S, T, F> From<(S, F)> for Value<T>
49where
50 S: Into<Shape>,
51 F: FnMut(usize) -> T,
52{
53 fn from(value: (S, F)) -> Self {
54 let (shape, mut f) = value;
55 let dims = shape.into();
56
57 let mut strides = vec![1; dims.rank()];
58 for i in (0..dims.rank() - 1).rev() {
59 strides[i] = strides[i + 1] * dims.dim_at(i + 1);
60 }
61
62 let size = dims.size();
63 let mut values = Vec::with_capacity(size);
64 for index in 0..size {
65 values.push(f(index));
66 }
67
68 Value::Array {
69 values: Arc::from(values),
70 shape: dims.clone(),
71 strides: Strides::from(strides),
72 }
73 }
74}
75
76impl<T: Debug> Debug for Value<T> {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 match self {
79 Value::Scalar(value) => write!(f, "Scalar({:?})", value),
80 Value::Array { shape, strides, .. } => {
81 write!(f, "Arr(shape={:?}, strides={:?})", shape, strides)
82 }
83 }
84 }
85}