Skip to main content

agent_sim/shared/
mod.rs

1use crate::sim::types::{SimSharedSlot, SimSharedSlotRaw};
2use memmap2::MmapMut;
3use std::fs::OpenOptions;
4use std::path::Path;
5use std::sync::atomic::{AtomicU64, Ordering};
6
/// Size in bytes of the writer-session field in [`SharedHeader`]; `encode_writer`
/// copies at most `WRITER_NAME_LEN - 1` bytes so the field always ends in a zero byte.
const WRITER_NAME_LEN: usize = 64;
/// Number of attempts `SharedRegion::read_snapshot` makes to observe an even,
/// unchanged generation before returning `SharedSnapshotError::Busy`.
const MAX_SNAPSHOT_SPINS: usize = 32;

/// Fixed header stored at byte offset 0 of the mapped region; the slot array
/// starts immediately after it (at `size_of::<SharedHeader>()`).
///
/// `#[repr(C)]` pins field order and offsets, so this struct *is* the shared
/// file format: reordering, resizing or adding fields changes the layout seen
/// by every process mapping the region.
#[repr(C)]
#[derive(Clone, Copy)]
struct SharedHeader {
    /// Sequence counter: odd while `publish` is writing slots, even when the
    /// slot array holds a stable snapshot.
    generation: u64,
    /// Slot count the region was initialized with; a reader's `open` rejects the
    /// region when this differs from the count it expects.
    slot_count: u32,
    /// Set to non-zero by `publish`; zero after a writer (re)initializes the region.
    initialized: u32,
    /// Writer-session name bytes, zero-padded (see `encode_writer`).
    writer_session: [u8; WRITER_NAME_LEN],
}
18
/// A memory-mapped file laid out as a [`SharedHeader`] followed by
/// `slot_count` raw slots.
///
/// The writer side publishes snapshots with `publish`; readers copy them out
/// with `read_snapshot`. NOTE(review): nothing here prevents two writers from
/// publishing through separate mappings of the same file — confirm that only a
/// single writer exists per region.
pub struct SharedRegion {
    /// Writable mapping of the entire backing file.
    mmap: MmapMut,
    /// Number of slots after the header; `open` checks that the file length is
    /// exactly `size_of::<SharedHeader>() + slot_count * size_of::<SimSharedSlotRaw>()`.
    slot_count: usize,
}
23
/// Reasons `SharedRegion::read_snapshot` can fail to return a snapshot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SharedSnapshotError {
    /// Every one of `attempts` reads saw an odd generation or a generation that
    /// changed while the slots were being read.
    Busy { attempts: usize },
    /// No snapshot has been published since the region was (re)initialized.
    Uninitialized,
    /// A slot read from the region failed to decode, or its slot id did not
    /// match its position; carries a human-readable description.
    Invalid(String),
}
30
31impl std::fmt::Display for SharedSnapshotError {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self {
34            Self::Busy { attempts } => {
35                write!(
36                    f,
37                    "shared snapshot remained unstable after {attempts} read attempts"
38                )
39            }
40            Self::Uninitialized => write!(f, "shared snapshot has not been published yet"),
41            Self::Invalid(message) => write!(f, "{message}"),
42        }
43    }
44}
45
46impl std::error::Error for SharedSnapshotError {}
47
48impl SharedRegion {
49    pub fn open(
50        path: &Path,
51        slot_count: usize,
52        writer_session: &str,
53        initialize: bool,
54    ) -> Result<Self, String> {
55        if let Some(parent) = path.parent() {
56            std::fs::create_dir_all(parent).map_err(|e| {
57                format!(
58                    "failed to create shared region parent '{}': {e}",
59                    parent.display()
60                )
61            })?;
62        }
63
64        let expected_len = Self::byte_len(slot_count);
65        if !initialize && !path.exists() {
66            return Err(format!(
67                "shared region '{}' has not been initialized by its writer yet",
68                path.display()
69            ));
70        }
71        let file = OpenOptions::new()
72            .create(initialize)
73            .truncate(false)
74            .read(true)
75            .write(true)
76            .open(path)
77            .map_err(|e| format!("failed to open shared region '{}': {e}", path.display()))?;
78        let current_len = file
79            .metadata()
80            .map_err(|e| format!("failed to inspect shared region '{}': {e}", path.display()))?
81            .len() as usize;
82        let mut should_initialize = initialize;
83        if current_len == 0 {
84            if !initialize {
85                return Err(format!(
86                    "shared region '{}' has not been initialized by its writer yet",
87                    path.display()
88                ));
89            }
90            file.set_len(expected_len as u64).map_err(|e| {
91                format!(
92                    "failed to size shared region '{}' to {} bytes: {e}",
93                    path.display(),
94                    expected_len
95                )
96            })?;
97            should_initialize = true;
98        } else if current_len != expected_len {
99            return Err(format!(
100                "shared region '{}' has size {} but expected {}",
101                path.display(),
102                current_len,
103                expected_len
104            ));
105        }
106
107        let mut mmap = unsafe {
108            MmapMut::map_mut(&file)
109                .map_err(|e| format!("failed to mmap shared region '{}': {e}", path.display()))?
110        };
111        if should_initialize {
112            let header = SharedHeader {
113                generation: 0,
114                slot_count: slot_count as u32,
115                initialized: 0,
116                writer_session: encode_writer(writer_session),
117            };
118            Self::write_header(&mut mmap, &header);
119            let offset = std::mem::size_of::<SharedHeader>();
120            mmap[offset..].fill(0);
121        } else {
122            let header = Self::read_header(&mmap);
123            if header.slot_count as usize != slot_count {
124                return Err(format!(
125                    "shared region '{}' slot count mismatch: region={} expected={}",
126                    path.display(),
127                    header.slot_count,
128                    slot_count
129                ));
130            }
131        }
132        Ok(Self { mmap, slot_count })
133    }
134
135    pub fn publish(&mut self, slots: &[SimSharedSlot]) -> Result<(), String> {
136        if slots.len() != self.slot_count {
137            return Err(format!(
138                "attempted to publish {} slots into region with capacity {}",
139                slots.len(),
140                self.slot_count
141            ));
142        }
143        for (expected_slot_id, slot) in slots.iter().enumerate() {
144            if slot.slot_id as usize != expected_slot_id {
145                return Err(format!(
146                    "shared slot id {} appeared at dense index {}; expected slot id {}",
147                    slot.slot_id, expected_slot_id, expected_slot_id
148                ));
149            }
150        }
151        let generation = self.generation();
152        self.set_generation(generation.wrapping_add(1)); // odd = write in progress
153        {
154            let slot_storage = self.slot_storage_mut();
155            for (idx, slot) in slots.iter().enumerate() {
156                slot_storage[idx] = slot.to_raw();
157            }
158        }
159        self.set_initialized(true);
160        self.set_generation(generation.wrapping_add(2)); // even = stable snapshot
161        self.mmap
162            .flush_async()
163            .map_err(|e| format!("failed flushing shared snapshot: {e}"))?;
164        Ok(())
165    }
166
167    pub fn read_snapshot(&self) -> Result<Vec<SimSharedSlot>, SharedSnapshotError> {
168        for _ in 0..MAX_SNAPSHOT_SPINS {
169            let before = self.generation();
170            if !before.is_multiple_of(2) {
171                std::hint::spin_loop();
172                continue;
173            }
174            if !self.is_initialized() {
175                return Err(SharedSnapshotError::Uninitialized);
176            }
177            let snapshot = self
178                .slot_storage()
179                .iter()
180                .enumerate()
181                .map(|(expected_slot_id, slot)| {
182                    let decoded = SimSharedSlot::try_from_raw(*slot).map_err(|err| {
183                        SharedSnapshotError::Invalid(format!(
184                            "shared snapshot slot {expected_slot_id} is invalid: {err}"
185                        ))
186                    })?;
187                    if decoded.slot_id as usize != expected_slot_id {
188                        return Err(SharedSnapshotError::Invalid(format!(
189                            "shared snapshot slot {expected_slot_id} reported slot id {}",
190                            decoded.slot_id
191                        )));
192                    }
193                    Ok(decoded)
194                })
195                .collect::<Result<Vec<_>, _>>()?;
196            let after = self.generation();
197            if before == after && after.is_multiple_of(2) {
198                return Ok(snapshot);
199            }
200            std::hint::spin_loop();
201        }
202        Err(SharedSnapshotError::Busy {
203            attempts: MAX_SNAPSHOT_SPINS,
204        })
205    }
206
207    fn byte_len(slot_count: usize) -> usize {
208        std::mem::size_of::<SharedHeader>() + (slot_count * std::mem::size_of::<SimSharedSlotRaw>())
209    }
210
211    fn read_header(mmap: &MmapMut) -> SharedHeader {
212        let header_ptr = mmap.as_ptr().cast::<SharedHeader>();
213        unsafe { *header_ptr }
214    }
215
216    fn write_header(mmap: &mut MmapMut, header: &SharedHeader) {
217        let header_ptr = mmap.as_mut_ptr().cast::<SharedHeader>();
218        unsafe {
219            *header_ptr = *header;
220        }
221    }
222
223    fn generation(&self) -> u64 {
224        let header = self.mmap.as_ptr().cast::<SharedHeader>();
225        let generation_ptr = unsafe { std::ptr::addr_of!((*header).generation) as *mut u64 };
226        // SAFETY:
227        // - `generation_ptr` points to the `generation` field inside the mmap-backed
228        //   `SharedHeader`, which is a valid, initialized `u64` for the lifetime of `self`.
229        // - `SharedHeader` is `#[repr(C)]`, so the field address is stable.
230        // - access to this field is performed atomically via `generation()`/`set_generation()`
231        //   once the region is initialized, satisfying `AtomicU64::from_ptr` requirements.
232        let generation = unsafe { AtomicU64::from_ptr(generation_ptr) };
233        generation.load(Ordering::Acquire)
234    }
235
236    fn set_generation(&mut self, value: u64) {
237        let header = self.mmap.as_mut_ptr().cast::<SharedHeader>();
238        let generation_ptr = unsafe { std::ptr::addr_of_mut!((*header).generation) };
239        let generation = unsafe { AtomicU64::from_ptr(generation_ptr) };
240        generation.store(value, Ordering::Release);
241    }
242
243    fn is_initialized(&self) -> bool {
244        let header = Self::read_header(&self.mmap);
245        header.initialized != 0
246    }
247
248    fn set_initialized(&mut self, initialized: bool) {
249        let header = self.mmap.as_mut_ptr().cast::<SharedHeader>();
250        unsafe {
251            (*header).initialized = u32::from(initialized);
252        }
253    }
254
255    fn slot_storage(&self) -> &[SimSharedSlotRaw] {
256        let offset = std::mem::size_of::<SharedHeader>();
257        let ptr = unsafe { self.mmap.as_ptr().add(offset).cast::<SimSharedSlotRaw>() };
258        unsafe { std::slice::from_raw_parts(ptr, self.slot_count) }
259    }
260
261    fn slot_storage_mut(&mut self) -> &mut [SimSharedSlotRaw] {
262        let offset = std::mem::size_of::<SharedHeader>();
263        let ptr = unsafe {
264            self.mmap
265                .as_mut_ptr()
266                .add(offset)
267                .cast::<SimSharedSlotRaw>()
268        };
269        unsafe { std::slice::from_raw_parts_mut(ptr, self.slot_count) }
270    }
271}
272
273fn encode_writer(writer_session: &str) -> [u8; WRITER_NAME_LEN] {
274    let mut out = [0_u8; WRITER_NAME_LEN];
275    let bytes = writer_session.as_bytes();
276    let len = bytes.len().min(WRITER_NAME_LEN.saturating_sub(1));
277    out[..len].copy_from_slice(&bytes[..len]);
278    out
279}
280
#[cfg(test)]
mod tests {
    //! Single-process tests for `SharedRegion`. Each test maps a fresh file in a
    //! temporary directory; some open a second mapping of the same file to act
    //! as the reader.
    use super::*;
    use crate::sim::types::SignalValue;

    /// A published snapshot is read back through a separately opened reader
    /// mapping, in slot-id order and with its values intact.
    #[test]
    fn shared_region_roundtrip_snapshot() {
        let dir = tempfile::tempdir().expect("tempdir should be creatable");
        let path = dir.path().join("region.bin");
        let mut region = SharedRegion::open(&path, 2, "writer", true)
            .expect("shared region should open for writer");
        region
            .publish(&[
                SimSharedSlot {
                    slot_id: 0,
                    value: SignalValue::F32(12.5),
                },
                SimSharedSlot {
                    slot_id: 1,
                    value: SignalValue::Bool(true),
                },
            ])
            .expect("publish should succeed");

        let reader =
            SharedRegion::open(&path, 2, "writer", false).expect("reader should open region");
        let snapshot = reader
            .read_snapshot()
            .expect("snapshot should be consistent");
        assert_eq!(snapshot.len(), 2);
        assert!(
            snapshot[0].slot_id == 0 && snapshot[0].value == SignalValue::F32(12.5),
            "slot 0 should round-trip its id and value"
        );
        assert!(
            snapshot[1].slot_id == 1 && snapshot[1].value == SignalValue::Bool(true),
            "slot 1 should round-trip its id and value"
        );
    }

    /// A reader cannot open a region its writer has never created.
    #[test]
    fn reader_requires_existing_region() {
        let dir = tempfile::tempdir().expect("tempdir should be creatable");
        let path = dir.path().join("missing.bin");
        let Err(message) = SharedRegion::open(&path, 2, "writer", false) else {
            panic!("reader must not open a region that was never created");
        };
        assert!(
            message.contains("has not been initialized by its writer yet"),
            "unexpected error: {message}"
        );
    }

    /// A reader expecting a different slot count is rejected.
    #[test]
    fn reader_rejects_mismatched_slot_count() {
        let dir = tempfile::tempdir().expect("tempdir should be creatable");
        let path = dir.path().join("region.bin");
        let _writer = SharedRegion::open(&path, 2, "writer", true)
            .expect("shared region should open for writer");
        assert!(
            SharedRegion::open(&path, 3, "writer", false).is_err(),
            "reader expecting 3 slots must not open a 2-slot region"
        );
    }

    /// Rejected publishes leave the generation and the previous snapshot untouched.
    #[test]
    fn publish_invalid_slot_id_does_not_poison_generation() {
        let dir = tempfile::tempdir().expect("tempdir should be creatable");
        let path = dir.path().join("region.bin");
        let mut region = SharedRegion::open(&path, 2, "writer", true)
            .expect("shared region should open for writer");
        region
            .publish(&[
                SimSharedSlot {
                    slot_id: 0,
                    value: SignalValue::F32(7.0),
                },
                SimSharedSlot {
                    slot_id: 1,
                    value: SignalValue::Bool(false),
                },
            ])
            .expect("initial publish should succeed");
        let before = region.generation();

        let err = region.publish(&[
            SimSharedSlot {
                slot_id: 0,
                value: SignalValue::F32(7.0),
            },
            SimSharedSlot {
                slot_id: 9,
                value: SignalValue::Bool(true),
            },
        ]);
        assert!(err.is_err(), "publish should fail for invalid slot id");
        assert_eq!(
            region.generation(),
            before,
            "failed publish must not leave generation in a poisoned state"
        );
        assert!(
            region.generation().is_multiple_of(2),
            "generation must remain even after failed publish"
        );
        let snapshot = region
            .read_snapshot()
            .expect("snapshot should remain readable after failed publish");
        assert!(
            snapshot
                .iter()
                .any(|slot| slot.slot_id == 0 && slot.value == SignalValue::F32(7.0)),
            "previous snapshot payload should remain readable after failed publish"
        );
    }

    /// Publishing at the top of the u64 range wraps the generation to an even 0.
    #[test]
    fn publish_wraps_generation_without_leaving_odd_state() {
        let dir = tempfile::tempdir().expect("tempdir should be creatable");
        let path = dir.path().join("region.bin");
        let mut region = SharedRegion::open(&path, 2, "writer", true)
            .expect("shared region should open for writer");
        region.set_generation(u64::MAX - 1);
        region.set_initialized(true);

        region
            .publish(&[
                SimSharedSlot {
                    slot_id: 0,
                    value: SignalValue::F32(0.0),
                },
                SimSharedSlot {
                    slot_id: 1,
                    value: SignalValue::Bool(true),
                },
            ])
            .expect("publish should succeed near generation rollover");

        let generation = region.generation();
        assert_eq!(generation, 0, "generation should wrap to 0 after publish");
        assert!(
            generation.is_multiple_of(2),
            "generation must remain even after wrapped publish"
        );
        let snapshot = region
            .read_snapshot()
            .expect("snapshot should remain readable after wrapped publish");
        assert!(
            snapshot
                .iter()
                .any(|slot| slot.slot_id == 1 && slot.value == SignalValue::Bool(true)),
            "snapshot payload should remain readable after wrapped publish"
        );
    }

    /// A generation stuck on an odd value makes readers give up with `Busy`.
    #[test]
    fn read_snapshot_fails_when_writer_never_finishes() {
        let dir = tempfile::tempdir().expect("tempdir should be creatable");
        let path = dir.path().join("region.bin");
        let mut region = SharedRegion::open(&path, 2, "writer", true)
            .expect("shared region should open for writer");
        region.set_generation(1);

        let err = region
            .read_snapshot()
            .expect_err("reader should refuse unstable snapshot");
        assert_eq!(
            err,
            SharedSnapshotError::Busy {
                attempts: MAX_SNAPSHOT_SPINS
            }
        );
    }

    /// Reopening as a writer discards the previously published snapshot.
    #[test]
    fn writer_reinitialization_clears_previous_snapshot() {
        let dir = tempfile::tempdir().expect("tempdir should be creatable");
        let path = dir.path().join("region.bin");
        let mut writer =
            SharedRegion::open(&path, 2, "writer", true).expect("writer should open shared region");
        writer
            .publish(&[
                SimSharedSlot {
                    slot_id: 0,
                    value: SignalValue::F32(9.5),
                },
                SimSharedSlot {
                    slot_id: 1,
                    value: SignalValue::Bool(false),
                },
            ])
            .expect("publish should succeed");

        let reopened = SharedRegion::open(&path, 2, "writer", true)
            .expect("reinitialized writer should reopen shared region");
        let err = reopened
            .read_snapshot()
            .expect_err("reinitialized writer should clear any previously published snapshot");
        assert_eq!(
            err,
            SharedSnapshotError::Uninitialized,
            "reinitializing a writer should leave the region unpublished until the writer publishes again"
        );
    }
}