Skip to main content

wave_runtime/
memory.rs

// Copyright 2026 Ojima Abraham
// SPDX-License-Identifier: Apache-2.0

//! Device memory buffer types for the WAVE runtime.
//!
//! Provides CPU-side buffer management for kernel arguments. In v1, buffers
//! are serialized to temporary files for subprocess-based kernel launch.
//! Direct GPU memory mapping is planned for v2.
9
10use crate::error::RuntimeError;
11use std::fmt;
12
/// Element data type for buffer contents.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementType {
    /// 32-bit float.
    F32,
    /// 32-bit unsigned integer.
    U32,
    /// 32-bit signed integer.
    I32,
    /// 16-bit float.
    F16,
    /// 64-bit float.
    F64,
}

impl ElementType {
    /// Size of one element in bytes.
    #[must_use]
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::F32 | Self::U32 | Self::I32 => 4,
            Self::F16 => 2,
            Self::F64 => 8,
        }
    }

    /// Short lowercase name of the type: `"f32"`, `"u32"`, `"i32"`, `"f16"`
    /// or `"f64"`. This is the text produced by the `Display` impl.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::U32 => "u32",
            Self::I32 => "i32",
            Self::F16 => "f16",
            Self::F64 => "f64",
        }
    }
}

impl fmt::Display for ElementType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write!`) honours width/alignment flags such as `{:>5}`,
        // so the type name lines up in formatted tables and error messages.
        f.pad(self.name())
    }
}
51
52/// CPU-side buffer representing device memory.
53#[derive(Debug, Clone)]
54pub struct DeviceBuffer {
55    /// Raw bytes of the buffer.
56    pub data: Vec<u8>,
57    /// Number of elements in the buffer.
58    pub count: usize,
59    /// Element data type.
60    pub element_type: ElementType,
61}
62
63impl DeviceBuffer {
64    /// Create a buffer from an `f32` slice.
65    #[must_use]
66    pub fn from_f32(data: &[f32]) -> Self {
67        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
68        Self {
69            data: bytes,
70            count: data.len(),
71            element_type: ElementType::F32,
72        }
73    }
74
75    /// Create a zero-filled `f32` buffer.
76    #[must_use]
77    pub fn zeros_f32(count: usize) -> Self {
78        Self {
79            data: vec![0u8; count * 4],
80            count,
81            element_type: ElementType::F32,
82        }
83    }
84
85    /// Create a buffer from a `u32` slice.
86    #[must_use]
87    pub fn from_u32(data: &[u32]) -> Self {
88        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
89        Self {
90            data: bytes,
91            count: data.len(),
92            element_type: ElementType::U32,
93        }
94    }
95
96    /// Create a zero-filled `u32` buffer.
97    #[must_use]
98    pub fn zeros_u32(count: usize) -> Self {
99        Self {
100            data: vec![0u8; count * 4],
101            count,
102            element_type: ElementType::U32,
103        }
104    }
105
106    /// Create a buffer from an `i32` slice.
107    #[must_use]
108    pub fn from_i32(data: &[i32]) -> Self {
109        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
110        Self {
111            data: bytes,
112            count: data.len(),
113            element_type: ElementType::I32,
114        }
115    }
116
117    /// Read buffer contents as `f32` values.
118    ///
119    /// # Errors
120    ///
121    /// Returns `RuntimeError::Memory` if the buffer element type is not `f32`.
122    pub fn to_f32(&self) -> Result<Vec<f32>, RuntimeError> {
123        if self.element_type != ElementType::F32 {
124            return Err(RuntimeError::Memory(format!(
125                "cannot read {} buffer as f32",
126                self.element_type
127            )));
128        }
129        Ok(self
130            .data
131            .chunks_exact(4)
132            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
133            .collect())
134    }
135
136    /// Read buffer contents as `u32` values.
137    ///
138    /// # Errors
139    ///
140    /// Returns `RuntimeError::Memory` if the buffer element type is not `u32`.
141    pub fn to_u32(&self) -> Result<Vec<u32>, RuntimeError> {
142        if self.element_type != ElementType::U32 {
143            return Err(RuntimeError::Memory(format!(
144                "cannot read {} buffer as u32",
145                self.element_type
146            )));
147        }
148        Ok(self
149            .data
150            .chunks_exact(4)
151            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
152            .collect())
153    }
154
155    /// Total size of the buffer in bytes.
156    #[must_use]
157    pub fn size_bytes(&self) -> usize {
158        self.data.len()
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_f32_roundtrip() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0];
        let buf = DeviceBuffer::from_f32(&data);
        assert_eq!(buf.count, 4);
        assert_eq!(buf.element_type, ElementType::F32);
        assert_eq!(buf.to_f32().unwrap(), data);
    }

    #[test]
    fn test_zeros_f32() {
        let buf = DeviceBuffer::zeros_f32(8);
        assert_eq!(buf.count, 8);
        assert_eq!(buf.size_bytes(), 32);
        assert_eq!(buf.to_f32().unwrap(), vec![0.0; 8]);
    }

    #[test]
    fn test_from_u32_roundtrip() {
        let data = vec![10_u32, 20, 30];
        let buf = DeviceBuffer::from_u32(&data);
        assert_eq!(buf.to_u32().unwrap(), data);
    }

    #[test]
    fn test_type_mismatch() {
        let buf = DeviceBuffer::from_u32(&[1, 2]);
        assert!(buf.to_f32().is_err());
    }

    #[test]
    fn test_to_u32_type_mismatch_message() {
        let err = DeviceBuffer::from_f32(&[1.0]).to_u32().unwrap_err();
        assert!(matches!(
            err,
            RuntimeError::Memory(ref msg) if msg.as_str() == "cannot read f32 buffer as u32"
        ));
    }

    #[test]
    fn test_from_i32_layout() {
        let buf = DeviceBuffer::from_i32(&[-1, 2]);
        assert_eq!(buf.count, 2);
        assert_eq!(buf.element_type, ElementType::I32);
        assert_eq!(buf.size_bytes(), 8);
        // Elements are encoded little-endian, back to back.
        let mut expected = Vec::new();
        expected.extend_from_slice(&(-1_i32).to_le_bytes());
        expected.extend_from_slice(&2_i32.to_le_bytes());
        assert_eq!(buf.data, expected);
    }

    #[test]
    fn test_element_type_size_and_display() {
        assert_eq!(ElementType::F32.size_bytes(), 4);
        assert_eq!(ElementType::I32.size_bytes(), 4);
        assert_eq!(ElementType::F16.size_bytes(), 2);
        assert_eq!(ElementType::F64.size_bytes(), 8);
        assert_eq!(ElementType::F16.to_string(), "f16");
        assert_eq!(ElementType::U32.to_string(), "u32");
    }

    #[test]
    fn test_empty_buffers() {
        let buf = DeviceBuffer::from_f32(&[]);
        assert_eq!(buf.count, 0);
        assert_eq!(buf.size_bytes(), 0);
        assert!(buf.to_f32().unwrap().is_empty());
        assert!(DeviceBuffer::zeros_u32(0).to_u32().unwrap().is_empty());
    }
}