nodedb_types/
vector_dtype.rs1#[repr(u8)]
24#[derive(
25 Debug,
26 Clone,
27 Copy,
28 Default,
29 PartialEq,
30 Eq,
31 Hash,
32 serde::Serialize,
33 serde::Deserialize,
34 zerompk::ToMessagePack,
35 zerompk::FromMessagePack,
36)]
37#[msgpack(c_enum)]
38#[non_exhaustive]
39pub enum VectorStorageDtype {
40 #[default]
42 F32 = 0,
43 F16 = 1,
46 BF16 = 2,
50}
51
52impl VectorStorageDtype {
53 pub const fn bytes_per_dim(self) -> usize {
55 match self {
56 Self::F32 => 4,
57 Self::F16 => 2,
58 Self::BF16 => 2,
59 }
60 }
61
62 pub const fn bytes_for_dim(self, dim: usize) -> usize {
64 dim * self.bytes_per_dim()
65 }
66
67 pub const fn as_str(self) -> &'static str {
70 match self {
71 Self::F32 => "f32",
72 Self::F16 => "f16",
73 Self::BF16 => "bf16",
74 }
75 }
76
77 pub fn parse(s: &str) -> Option<Self> {
81 match s {
82 "f32" => Some(Self::F32),
83 "f16" => Some(Self::F16),
84 "bf16" => Some(Self::BF16),
85 _ => None,
86 }
87 }
88}
89
90impl core::str::FromStr for VectorStorageDtype {
91 type Err = ();
92
93 fn from_str(s: &str) -> Result<Self, Self::Err> {
94 Self::parse(s).ok_or(())
95 }
96}
97
98impl core::fmt::Display for VectorStorageDtype {
99 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100 f.write_str(self.as_str())
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant. The per-variant round-trip tests iterate this list so
    /// they cannot silently skip one. Extend it whenever a variant is added.
    const ALL: [VectorStorageDtype; 3] = [
        VectorStorageDtype::F32,
        VectorStorageDtype::F16,
        VectorStorageDtype::BF16,
    ];

    #[test]
    fn default_is_f32() {
        assert_eq!(VectorStorageDtype::default(), VectorStorageDtype::F32);
    }

    #[test]
    fn discriminants_are_stable() {
        // The enum is `#[repr(u8)]` with explicit values. Pin them so that a
        // reorder or renumber shows up as a test failure.
        assert_eq!(VectorStorageDtype::F32 as u8, 0);
        assert_eq!(VectorStorageDtype::F16 as u8, 1);
        assert_eq!(VectorStorageDtype::BF16 as u8, 2);
    }

    #[test]
    fn bytes_per_dim_matches_element_widths() {
        assert_eq!(VectorStorageDtype::F32.bytes_per_dim(), 4);
        assert_eq!(VectorStorageDtype::F16.bytes_per_dim(), 2);
        assert_eq!(VectorStorageDtype::BF16.bytes_per_dim(), 2);
    }

    #[test]
    fn bytes_for_dim_is_dim_times_width() {
        assert_eq!(VectorStorageDtype::F32.bytes_for_dim(128), 512);
        assert_eq!(VectorStorageDtype::BF16.bytes_for_dim(1536), 3072);
        assert_eq!(VectorStorageDtype::F16.bytes_for_dim(256), 512);
        for v in ALL {
            assert_eq!(v.bytes_for_dim(0), 0);
            assert_eq!(v.bytes_for_dim(1), v.bytes_per_dim());
        }
    }

    #[test]
    fn as_str_roundtrips_through_parse() {
        for v in ALL {
            assert_eq!(VectorStorageDtype::parse(v.as_str()), Some(v));
        }
    }

    #[test]
    fn parse_rejects_unknown_and_non_canonical() {
        assert_eq!(VectorStorageDtype::parse("fp8"), None);
        assert_eq!(VectorStorageDtype::parse("F32"), None);
        assert_eq!(VectorStorageDtype::parse(" f32"), None);
        assert_eq!(VectorStorageDtype::parse(""), None);
    }

    #[test]
    fn from_str_agrees_with_parse() {
        for v in ALL {
            assert_eq!(v.as_str().parse::<VectorStorageDtype>(), Ok(v));
        }
        assert_eq!("fp8".parse::<VectorStorageDtype>(), Err(()));
        assert_eq!("BF16".parse::<VectorStorageDtype>(), Err(()));
    }

    #[test]
    fn display_matches_as_str() {
        for v in ALL {
            assert_eq!(v.to_string(), v.as_str());
        }
    }

    #[test]
    fn msgpack_roundtrip() {
        for v in ALL {
            let bytes = zerompk::to_msgpack_vec(&v).unwrap();
            let restored: VectorStorageDtype = zerompk::from_msgpack(&bytes).unwrap();
            assert_eq!(restored, v);
        }
    }
}