1use crate::error::RuntimeError;
11use std::fmt;
12
/// Scalar element type stored in a [`DeviceBuffer`].
///
/// Each variant names a fixed-width numeric type; the per-element width is
/// reported by `ElementType::size_bytes`.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum ElementType {
    /// 32-bit float (4 bytes).
    F32,
    /// 32-bit unsigned integer (4 bytes).
    U32,
    /// 32-bit signed integer (4 bytes).
    I32,
    /// 16-bit float (2 bytes).
    F16,
    /// 64-bit float (8 bytes).
    F64,
}
27
28impl ElementType {
29 #[must_use]
31 pub fn size_bytes(self) -> usize {
32 match self {
33 Self::F32 | Self::U32 | Self::I32 => 4,
34 Self::F16 => 2,
35 Self::F64 => 8,
36 }
37 }
38}
39
40impl fmt::Display for ElementType {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 Self::F32 => write!(f, "f32"),
44 Self::U32 => write!(f, "u32"),
45 Self::I32 => write!(f, "i32"),
46 Self::F16 => write!(f, "f16"),
47 Self::F64 => write!(f, "f64"),
48 }
49 }
50}
51
52#[derive(Debug, Clone)]
54pub struct DeviceBuffer {
55 pub data: Vec<u8>,
57 pub count: usize,
59 pub element_type: ElementType,
61}
62
63impl DeviceBuffer {
64 #[must_use]
66 pub fn from_f32(data: &[f32]) -> Self {
67 let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
68 Self {
69 data: bytes,
70 count: data.len(),
71 element_type: ElementType::F32,
72 }
73 }
74
75 #[must_use]
77 pub fn zeros_f32(count: usize) -> Self {
78 Self {
79 data: vec![0u8; count * 4],
80 count,
81 element_type: ElementType::F32,
82 }
83 }
84
85 #[must_use]
87 pub fn from_u32(data: &[u32]) -> Self {
88 let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
89 Self {
90 data: bytes,
91 count: data.len(),
92 element_type: ElementType::U32,
93 }
94 }
95
96 #[must_use]
98 pub fn zeros_u32(count: usize) -> Self {
99 Self {
100 data: vec![0u8; count * 4],
101 count,
102 element_type: ElementType::U32,
103 }
104 }
105
106 #[must_use]
108 pub fn from_i32(data: &[i32]) -> Self {
109 let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
110 Self {
111 data: bytes,
112 count: data.len(),
113 element_type: ElementType::I32,
114 }
115 }
116
117 pub fn to_f32(&self) -> Result<Vec<f32>, RuntimeError> {
123 if self.element_type != ElementType::F32 {
124 return Err(RuntimeError::Memory(format!(
125 "cannot read {} buffer as f32",
126 self.element_type
127 )));
128 }
129 Ok(self
130 .data
131 .chunks_exact(4)
132 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
133 .collect())
134 }
135
136 pub fn to_u32(&self) -> Result<Vec<u32>, RuntimeError> {
142 if self.element_type != ElementType::U32 {
143 return Err(RuntimeError::Memory(format!(
144 "cannot read {} buffer as u32",
145 self.element_type
146 )));
147 }
148 Ok(self
149 .data
150 .chunks_exact(4)
151 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
152 .collect())
153 }
154
155 #[must_use]
157 pub fn size_bytes(&self) -> usize {
158 self.data.len()
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_f32_roundtrip() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0];
        let buf = DeviceBuffer::from_f32(&data);
        assert_eq!(buf.count, 4);
        assert_eq!(buf.element_type, ElementType::F32);
        assert_eq!(buf.to_f32().unwrap(), data);
    }

    #[test]
    fn test_zeros_f32() {
        let buf = DeviceBuffer::zeros_f32(8);
        assert_eq!(buf.count, 8);
        assert_eq!(buf.size_bytes(), 32);
        assert_eq!(buf.to_f32().unwrap(), vec![0.0; 8]);
    }

    #[test]
    fn test_from_u32_roundtrip() {
        let data = vec![10_u32, 20, 30];
        let buf = DeviceBuffer::from_u32(&data);
        assert_eq!(buf.count, 3);
        assert_eq!(buf.element_type, ElementType::U32);
        assert_eq!(buf.to_u32().unwrap(), data);
    }

    #[test]
    fn test_zeros_u32() {
        let buf = DeviceBuffer::zeros_u32(5);
        assert_eq!(buf.count, 5);
        assert_eq!(buf.size_bytes(), 20);
        assert_eq!(buf.element_type, ElementType::U32);
        assert_eq!(buf.to_u32().unwrap(), vec![0; 5]);
    }

    #[test]
    fn test_from_i32_layout() {
        // -1 is all-ones in two's complement; bytes are little-endian.
        let buf = DeviceBuffer::from_i32(&[-1, 2]);
        assert_eq!(buf.count, 2);
        assert_eq!(buf.element_type, ElementType::I32);
        assert_eq!(buf.data, vec![0xff, 0xff, 0xff, 0xff, 2, 0, 0, 0]);
    }

    #[test]
    fn test_little_endian_encoding() {
        let buf = DeviceBuffer::from_u32(&[0x0403_0201]);
        assert_eq!(buf.data, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_empty_buffer() {
        let buf = DeviceBuffer::from_f32(&[]);
        assert_eq!(buf.count, 0);
        assert_eq!(buf.size_bytes(), 0);
        assert!(buf.to_f32().unwrap().is_empty());
    }

    #[test]
    fn test_type_mismatch() {
        let buf = DeviceBuffer::from_u32(&[1, 2]);
        assert!(buf.to_f32().is_err());

        let floats = DeviceBuffer::from_f32(&[1.0]);
        assert!(floats.to_u32().is_err());

        // i32 buffers must not be readable through the u32/f32 accessors.
        let ints = DeviceBuffer::from_i32(&[1]);
        assert!(ints.to_u32().is_err());
        assert!(ints.to_f32().is_err());
    }

    #[test]
    fn test_element_type_size_and_name() {
        let cases = [
            (ElementType::F32, 4, "f32"),
            (ElementType::U32, 4, "u32"),
            (ElementType::I32, 4, "i32"),
            (ElementType::F16, 2, "f16"),
            (ElementType::F64, 8, "f64"),
        ];
        for &(ty, size, name) in &cases {
            assert_eq!(ty.size_bytes(), size);
            assert_eq!(ty.to_string(), name);
        }
    }
}