// dynamo_memory/actions.rs
1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Storage actions.
5
6use super::{MemoryDescriptor, StorageError};
7
8/// Extension trait for storage types that support memory setting operations
9pub trait Memset: MemoryDescriptor {
10    /// Sets a region of memory to a specific value
11    ///
12    /// # Arguments
13    /// * `value` - The value to set (will be truncated to u8)
14    /// * `offset` - Offset in bytes from the start of the storage
15    /// * `size` - Number of bytes to set
16    ///
17    /// # Safety
18    /// The caller must ensure:
19    /// - offset + size <= self.size()
20    /// - No other references exist to the memory region being set
21    fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<(), StorageError>;
22}
23
24/// Extension trait for storage types that support slicing operations
25pub trait Slice: MemoryDescriptor + 'static {
26    /// Returns an immutable byte slice view of the entire storage region
27    ///
28    /// # Safety
29    /// This is an unsafe method. The caller must ensure:
30    /// - The memory region remains valid for the lifetime of the returned slice
31    /// - The memory region is properly initialized
32    /// - No concurrent mutable access occurs while the slice is in use
33    /// - The memory backing this storage remains valid (implementors with owned
34    ///   memory satisfy this, but care must be taken with unowned memory regions)
35    unsafe fn as_slice(&self) -> Result<&[u8], StorageError>;
36
37    /// Returns an immutable byte slice view of a subregion
38    ///
39    /// # Arguments
40    /// * `offset` - Offset in bytes from the start of the storage
41    /// * `len` - Number of bytes to slice
42    ///
43    /// # Safety
44    /// The caller must ensure:
45    /// - offset + len <= self.size()
46    /// - The memory region is valid and initialized
47    /// - No concurrent mutable access occurs while the slice is in use
48    fn slice(&self, offset: usize, len: usize) -> Result<&[u8], StorageError> {
49        // SAFETY: Caller guarantees memory validity per trait's safety contract
50        let slice = unsafe { self.as_slice()? };
51
52        // validate offset and len
53        if offset.saturating_add(len) > slice.len() {
54            return Err(StorageError::Unsupported("slice out of bounds".into()));
55        }
56
57        slice
58            .get(offset..offset.saturating_add(len))
59            .ok_or_else(|| StorageError::Unsupported("slice out of bounds".into()))
60    }
61
62    /// Returns a typed immutable slice view of the entire storage region
63    ///
64    /// # Safety
65    /// The caller must ensure:
66    /// - The memory region is valid and initialized
67    /// - The memory is properly aligned for type T
68    /// - The size is a multiple of `size_of::<T>()`
69    /// - No concurrent mutable access occurs while the slice is in use
70    /// - The data represents valid values of type T
71    fn as_slice_typed<T: Sized>(&self) -> Result<&[T], StorageError> {
72        // SAFETY: Caller guarantees memory validity per trait's safety contract
73        let bytes = unsafe { self.as_slice()? };
74        let ptr = bytes.as_ptr() as *const T;
75        let elem_size = std::mem::size_of::<T>();
76        if elem_size == 0 {
77            return Err(StorageError::Unsupported(
78                "zero-sized types are not supported".into(),
79            ));
80        }
81        let len = bytes.len() / elem_size;
82
83        if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
84            return Err(StorageError::Unsupported(format!(
85                "memory not aligned for type (required alignment: {})",
86                std::mem::align_of::<T>()
87            )));
88        }
89
90        if bytes.len() % elem_size != 0 {
91            return Err(StorageError::Unsupported(format!(
92                "size {} is not a multiple of type size {}",
93                bytes.len(),
94                elem_size
95            )));
96        }
97
98        // SAFETY: Caller guarantees memory is valid, aligned, and properly initialized for T
99        Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
100    }
101
102    /// Returns a typed immutable slice view of a subregion
103    ///
104    /// # Arguments
105    /// * `offset` - Offset in bytes from the start of the storage
106    /// * `len` - Number of elements of type T to slice
107    ///
108    /// # Safety
109    /// The caller must ensure:
110    /// - offset + (len * size_of::<T>()) <= self.size()
111    /// - offset is properly aligned for type T
112    /// - The memory region is valid and initialized
113    /// - No concurrent mutable access occurs while the slice is in use
114    /// - The data represents valid values of type T
115    fn slice_typed<T: Sized>(&self, offset: usize, len: usize) -> Result<&[T], StorageError> {
116        let type_size = std::mem::size_of::<T>();
117        let byte_len = len
118            .checked_mul(type_size)
119            .ok_or_else(|| StorageError::Unsupported("length overflow".into()))?;
120
121        let bytes = self.slice(offset, byte_len)?;
122        let ptr = bytes.as_ptr() as *const T;
123
124        if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
125            return Err(StorageError::Unsupported(format!(
126                "memory not aligned for type (required alignment: {})",
127                std::mem::align_of::<T>()
128            )));
129        }
130
131        // SAFETY: Caller guarantees memory is valid, aligned, and properly initialized for T
132        Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
133    }
134}
135
136/// Extension trait for storage types that support mutable slicing operations.
137pub trait SliceMut: MemoryDescriptor + 'static {
138    /// Returns a mutable byte slice view of the entire storage region
139    ///
140    /// # Safety
141    /// This is an unsafe method. The caller must ensure:
142    /// - The memory region remains valid for the lifetime of the returned slice
143    /// - The memory region is valid and accessible
144    /// - No other references (mutable or immutable) exist to this memory region
145    /// - The memory backing this storage remains valid (implementors with owned
146    ///   memory satisfy this, but care must be taken with unowned memory regions)
147    unsafe fn as_slice_mut(&mut self) -> Result<&mut [u8], StorageError>;
148
149    /// Returns a mutable byte slice view of a subregion
150    ///
151    /// # Arguments
152    /// * `offset` - Offset in bytes from the start of the storage
153    /// * `len` - Number of bytes to slice
154    ///
155    /// # Safety
156    /// The caller must ensure:
157    /// - offset + len <= self.size()
158    /// - The memory region is valid
159    /// - No other references (mutable or immutable) exist to this memory region
160    fn slice_mut(&mut self, offset: usize, len: usize) -> Result<&mut [u8], StorageError> {
161        // SAFETY: Caller guarantees memory validity per trait's safety contract
162        let slice = unsafe { self.as_slice_mut()? };
163
164        // validate offset and len
165        if offset.saturating_add(len) > slice.len() {
166            return Err(StorageError::Unsupported("slice out of bounds".into()));
167        }
168
169        slice
170            .get_mut(offset..offset.saturating_add(len))
171            .ok_or_else(|| StorageError::Unsupported("slice out of bounds".into()))
172    }
173
174    /// Returns a typed mutable slice view of the entire storage region
175    ///
176    /// # Safety
177    /// The caller must ensure:
178    /// - The memory region is valid
179    /// - The memory is properly aligned for type T
180    /// - The size is a multiple of `size_of::<T>()`
181    /// - No other references (mutable or immutable) exist to this memory region
182    fn as_slice_typed_mut<T: Sized>(&mut self) -> Result<&mut [T], StorageError> {
183        // SAFETY: Caller guarantees memory validity per trait's safety contract
184        let bytes = unsafe { self.as_slice_mut()? };
185        let ptr = bytes.as_mut_ptr() as *mut T;
186        let len = bytes.len() / std::mem::size_of::<T>();
187
188        if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
189            return Err(StorageError::Unsupported(format!(
190                "memory not aligned for type (required alignment: {})",
191                std::mem::align_of::<T>()
192            )));
193        }
194
195        if bytes.len() % std::mem::size_of::<T>() != 0 {
196            return Err(StorageError::Unsupported(format!(
197                "size {} is not a multiple of type size {}",
198                bytes.len(),
199                std::mem::size_of::<T>()
200            )));
201        }
202
203        // SAFETY: Caller guarantees memory is valid, aligned, and no aliasing
204        Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
205    }
206
207    /// Returns a typed mutable slice view of a subregion
208    ///
209    /// # Arguments
210    /// * `offset` - Offset in bytes from the start of the storage
211    /// * `len` - Number of elements of type T to slice
212    ///
213    /// # Safety
214    /// The caller must ensure:
215    /// - offset + (len * size_of::<T>()) <= self.size()
216    /// - offset is properly aligned for type T
217    /// - The memory region is valid
218    /// - No other references (mutable or immutable) exist to this memory region
219    fn slice_typed_mut<T: Sized>(
220        &mut self,
221        offset: usize,
222        len: usize,
223    ) -> Result<&mut [T], StorageError> {
224        let type_size = std::mem::size_of::<T>();
225        let byte_len = len
226            .checked_mul(type_size)
227            .ok_or_else(|| StorageError::Unsupported("length overflow".into()))?;
228
229        let bytes = self.slice_mut(offset, byte_len)?;
230        let ptr = bytes.as_mut_ptr() as *mut T;
231
232        if !(bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<T>()) {
233            return Err(StorageError::Unsupported(format!(
234                "memory not aligned for type (required alignment: {})",
235                std::mem::align_of::<T>()
236            )));
237        }
238
239        // SAFETY: Caller guarantees memory is valid, aligned, and no aliasing
240        Ok(unsafe { std::slice::from_raw_parts_mut(ptr, len) })
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SystemStorage;

    // Helper to create a test storage backed by system memory.
    // NOTE(review): relies on `SystemStorage` implementing the Memset/Slice
    // traits defined above — verify against its impl blocks.
    fn create_storage(size: usize) -> SystemStorage {
        SystemStorage::new(size).expect("allocation failed")
    }

    // ========== Memset tests ==========

    #[test]
    fn test_memset_full_region() {
        let mut storage = create_storage(1024);
        storage
            .memset(0xAB, 0, 1024)
            .expect("memset should succeed");

        // SAFETY: storage owns its memory and no other references exist here.
        let slice = unsafe { storage.as_slice().expect("as_slice should succeed") };
        assert!(slice.iter().all(|&b| b == 0xAB));
    }

    #[test]
    fn test_memset_partial_region() {
        let mut storage = create_storage(1024);
        // First fill with 0x00
        storage
            .memset(0x00, 0, 1024)
            .expect("memset should succeed");
        // Then fill middle region with 0xFF
        storage
            .memset(0xFF, 100, 200)
            .expect("memset should succeed");

        let slice = unsafe { storage.as_slice().expect("as_slice should succeed") };
        // Check before region
        assert!(slice[..100].iter().all(|&b| b == 0x00));
        // Check filled region
        assert!(slice[100..300].iter().all(|&b| b == 0xFF));
        // Check after region
        assert!(slice[300..].iter().all(|&b| b == 0x00));
    }

    #[test]
    fn test_memset_at_end() {
        let mut storage = create_storage(1024);
        // Fill the last 100 bytes (offset + size == storage size exactly)
        storage
            .memset(0x42, 924, 100)
            .expect("memset should succeed");

        let slice = unsafe { storage.as_slice().expect("as_slice should succeed") };
        assert!(slice[924..].iter().all(|&b| b == 0x42));
    }

    #[test]
    fn test_memset_zero_size() {
        let mut storage = create_storage(1024);
        // Zero-size memset should succeed (no-op)
        storage
            .memset(0xFF, 500, 0)
            .expect("zero-size memset should succeed");
    }

    #[test]
    fn test_memset_out_of_bounds() {
        let mut storage = create_storage(1024);
        // Try to write beyond the storage (900 + 200 > 1024)
        let result = storage.memset(0xFF, 900, 200);
        assert!(result.is_err());
    }

    #[test]
    fn test_memset_offset_overflow() {
        let mut storage = create_storage(1024);
        // offset + size would overflow usize; must error, not wrap
        let result = storage.memset(0xFF, usize::MAX, 1);
        assert!(result.is_err());
    }

    // ========== Slice tests ==========

    #[test]
    fn test_as_slice_full() {
        let mut storage = create_storage(1024);
        storage
            .memset(0xCD, 0, 1024)
            .expect("memset should succeed");

        let slice = unsafe { storage.as_slice().expect("as_slice should succeed") };
        assert_eq!(slice.len(), 1024);
        assert!(slice.iter().all(|&b| b == 0xCD));
    }

    #[test]
    fn test_slice_partial() {
        let mut storage = create_storage(1024);
        storage
            .memset(0x00, 0, 1024)
            .expect("memset should succeed");
        storage
            .memset(0xAA, 100, 50)
            .expect("memset should succeed");

        // The subregion view must cover exactly the bytes written above.
        let partial = storage.slice(100, 50).expect("slice should succeed");
        assert_eq!(partial.len(), 50);
        assert!(partial.iter().all(|&b| b == 0xAA));
    }

    #[test]
    fn test_slice_at_start() {
        let storage = create_storage(1024);
        let slice = storage.slice(0, 100).expect("slice should succeed");
        assert_eq!(slice.len(), 100);
    }

    #[test]
    fn test_slice_at_end() {
        let storage = create_storage(1024);
        // offset + len == storage size exactly (inclusive upper boundary)
        let slice = storage.slice(924, 100).expect("slice should succeed");
        assert_eq!(slice.len(), 100);
    }

    #[test]
    fn test_slice_zero_length() {
        let storage = create_storage(1024);
        let slice = storage
            .slice(500, 0)
            .expect("zero-length slice should succeed");
        assert!(slice.is_empty());
    }

    #[test]
    fn test_slice_out_of_bounds() {
        let storage = create_storage(1024);
        let result = storage.slice(900, 200);
        assert!(result.is_err());
    }

    #[test]
    fn test_slice_offset_overflow() {
        let storage = create_storage(1024);
        // offset + len overflows usize; the bounds check must reject it
        let result = storage.slice(usize::MAX, 1);
        assert!(result.is_err());
    }

    // ========== Typed slice tests ==========

    #[test]
    fn test_as_slice_typed_u32() {
        let mut storage = create_storage(1024);
        // Fill with known pattern
        storage
            .memset(0x00, 0, 1024)
            .expect("memset should succeed");

        let typed: &[u32] = storage
            .as_slice_typed()
            .expect("typed slice should succeed");
        assert_eq!(typed.len(), 256); // 1024 / 4
        assert!(typed.iter().all(|&v| v == 0));
    }

    #[test]
    fn test_as_slice_typed_u64() {
        let storage = create_storage(1024);
        let typed: &[u64] = storage
            .as_slice_typed()
            .expect("typed slice should succeed");
        assert_eq!(typed.len(), 128); // 1024 / 8
    }

    #[test]
    fn test_slice_typed_partial() {
        let mut storage = create_storage(1024);
        storage
            .memset(0x00, 0, 1024)
            .expect("memset should succeed");

        // Slice 10 u32 elements starting at offset 0
        let typed: &[u32] = storage
            .slice_typed(0, 10)
            .expect("typed slice should succeed");
        assert_eq!(typed.len(), 10);
    }

    #[test]
    fn test_slice_typed_with_offset() {
        let storage = create_storage(1024);
        // Slice starting at offset 64 (aligned for u64)
        let typed: &[u64] = storage
            .slice_typed(64, 5)
            .expect("typed slice should succeed");
        assert_eq!(typed.len(), 5);
    }

    #[test]
    fn test_as_slice_typed_zst_error() {
        let storage = create_storage(1024);
        // Zero-sized types should fail with an error, not panic
        let result: Result<&[()], _> = storage.as_slice_typed();
        assert!(result.is_err());
    }

    #[test]
    fn test_as_slice_typed_size_not_multiple() {
        // Create storage with size not divisible by 4
        let storage = create_storage(1023);
        let result: Result<&[u32], _> = storage.as_slice_typed();
        assert!(result.is_err());
    }

    #[test]
    fn test_slice_typed_length_overflow() {
        let storage = create_storage(1024);
        // len * size_of::<u64>() would overflow usize
        let result: Result<&[u64], _> = storage.slice_typed(0, usize::MAX);
        assert!(result.is_err());
    }

    #[test]
    fn test_slice_typed_out_of_bounds() {
        let storage = create_storage(1024);
        // Request more elements than available (200 * 8 > 1024)
        let result: Result<&[u64], _> = storage.slice_typed(0, 200);
        assert!(result.is_err());
    }
473}