// autocore_std/shm.rs

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use log::{info, warn};
use mechutil::shm::ShmContext;
use shared_memory::Shmem;
use tokio::sync::Mutex;

/// Address of a location inside a shared-memory mapping, kept as a bare `usize`.
///
/// A `usize` is `Send + Sync`, unlike a raw pointer, so this value can be copied
/// freely. The pointee type is not recorded: [`ShmPtr::as_ptr`] reinterprets the
/// address as whatever `*mut T` the caller requests.
#[derive(Debug, Clone, Copy)]
pub struct ShmPtr(pub usize);

impl ShmPtr {
    /// Captures the address held by `ptr`, discarding its type.
    pub fn new<T>(ptr: *mut T) -> Self {
        let addr = ptr as usize;
        ShmPtr(addr)
    }

    /// Reinterprets the stored address as a `*mut T`.
    ///
    /// Producing the pointer is safe. Dereferencing it is only sound while the
    /// address still refers to a live, correctly aligned value of type `T`.
    pub fn as_ptr<T>(&self) -> *mut T {
        let ShmPtr(addr) = *self;
        addr as *mut T
    }
}

23/// Wrapper for Shmem to implement Send + Sync
24pub struct ThreadSafeShmem(pub Shmem);
25
26unsafe impl Send for ThreadSafeShmem {}
27unsafe impl Sync for ThreadSafeShmem {}
28
29impl std::fmt::Debug for ThreadSafeShmem {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_tuple("ThreadSafeShmem")
32            .field(&self.0.get_os_id())
33            .finish()
34    }
35}
36
37impl std::ops::Deref for ThreadSafeShmem {
38    type Target = Shmem;
39    fn deref(&self) -> &Self::Target {
40        &self.0
41    }
42}
43
44/// Send+Sync wrapper for mechutil's SHM context handle.
45#[derive(Clone)]
46pub struct SendShmContext(pub Arc<Mutex<Option<Arc<ShmContext>>>>);
47
48unsafe impl Send for SendShmContext {}
49unsafe impl Sync for SendShmContext {}
50
51/// A generic store for mapping names to Shared Memory pointers.
52#[derive(Debug, Clone)]
53pub struct ShmMap {
54    pointers: Arc<HashMap<String, ShmPtr>>,
55    // We keep the shm context alive if needed, or at least a reference
56    _context: Option<Arc<ThreadSafeShmem>>, 
57}
58
59impl ShmMap {
60    pub fn new(pointers: HashMap<String, ShmPtr>) -> Self {
61        Self {
62            pointers: Arc::new(pointers),
63            _context: None,
64        }
65    }
66
67    pub fn get<T>(&self, name: &str) -> Option<*mut T> {
68        self.pointers.get(name).map(|p| p.as_ptr())
69    }
70    
71    pub fn get_raw(&self, name: &str) -> Option<usize> {
72        self.pointers.get(name).map(|p| p.0)
73    }
74
75    /// Read a value from shared memory. 
76    /// SAFETY: Caller must ensure T matches the type of data at the pointer.
77    pub unsafe fn read<T: Copy>(&self, name: &str) -> Option<T> {
78        self.get::<T>(name).map(|ptr| *ptr)
79    }
80
81    /// Write a value to shared memory.
82    /// SAFETY: Caller must ensure T matches the type of data at the pointer.
83    pub unsafe fn write<T: Copy>(&self, name: &str, value: T) -> bool {
84        if let Some(ptr) = self.get::<T>(name) {
85            *ptr = value;
86            true
87        } else {
88            false
89        }
90    }
91}
92
93/// Helper to wait for SHM context and resolve pointers.
94/// 
95/// * `ctx_handle`: The SendShmContext wrapper.
96/// * `names`: List of variable names to resolve.
97/// * `timeout_ms`: Max time to wait for configuration.
98pub async fn resolve_shm_pointers(
99    ctx_handle: SendShmContext,
100    names: Vec<String>,
101    timeout_ms: u64
102) -> Option<ShmMap> {
103    tokio::task::spawn_blocking(move || {
104        let ctx = ctx_handle;
105        let shm_ctx = &ctx.0;
106        let mut attempts = 0;
107        let max_attempts = timeout_ms / 100;
108
109        loop {
110            // Use try_lock to avoid needing async lock in blocking task
111            if let Ok(guard) = shm_ctx.try_lock() {
112                if let Some(ref ctx) = *guard {
113                    info!("SHM context received, resolving {} pointers...", names.len());
114                    let mut pointers = HashMap::new();
115                    let mut mapped = 0;
116                    
117                    for name in &names {
118                        if let Some(ptr) = unsafe { ctx.get_pointer(name) } {
119                            pointers.insert(name.clone(), ShmPtr::new(ptr));
120                            mapped += 1;
121                        }
122                    }
123                    info!("SHM pointers resolved: {}/{}", mapped, names.len());
124                    return Some(ShmMap::new(pointers));
125                }
126            }
127            
128            attempts += 1;
129            if attempts > max_attempts {
130                warn!("Timed out waiting for SHM context after {}ms", timeout_ms);
131                return None;
132            }
133            
134            std::thread::sleep(Duration::from_millis(100));
135        }
136    }).await.unwrap_or(None)
137}