1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4use log::{info, warn};
5use shared_memory::Shmem;
6use tokio::sync::Mutex;
7use mechutil::shm::ShmContext;
8
/// An address inside a shared-memory mapping, stored as a plain integer.
///
/// Holding the address as `usize` instead of a raw pointer makes the type
/// `Copy`, `Send` and `Sync` automatically, so it can be moved freely between
/// threads and tasks. Converting back with [`ShmPtr::as_ptr`] is safe; only
/// dereferencing the resulting pointer requires `unsafe`.
#[derive(Clone, Copy, Debug)]
pub struct ShmPtr(pub usize);

impl ShmPtr {
    /// Captures the numeric address of `ptr`; the pointee type is discarded.
    pub fn new<T>(ptr: *mut T) -> Self {
        let addr = ptr as usize;
        ShmPtr(addr)
    }

    /// Reinterprets the stored address as a `*mut T`.
    ///
    /// No check is made that a `T` actually lives at this address.
    pub fn as_ptr<T>(&self) -> *mut T {
        let ShmPtr(addr) = *self;
        addr as *mut T
    }
}
22
23pub struct ThreadSafeShmem(pub Shmem);
25
26unsafe impl Send for ThreadSafeShmem {}
27unsafe impl Sync for ThreadSafeShmem {}
28
29impl std::fmt::Debug for ThreadSafeShmem {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_tuple("ThreadSafeShmem")
32 .field(&self.0.get_os_id())
33 .finish()
34 }
35}
36
37impl std::ops::Deref for ThreadSafeShmem {
38 type Target = Shmem;
39 fn deref(&self) -> &Self::Target {
40 &self.0
41 }
42}
43
44#[derive(Clone)]
46pub struct SendShmContext(pub Arc<Mutex<Option<Arc<ShmContext>>>>);
47
48unsafe impl Send for SendShmContext {}
49unsafe impl Sync for SendShmContext {}
50
51#[derive(Debug, Clone)]
53pub struct ShmMap {
54 pointers: Arc<HashMap<String, ShmPtr>>,
55 _context: Option<Arc<ThreadSafeShmem>>,
57}
58
59impl ShmMap {
60 pub fn new(pointers: HashMap<String, ShmPtr>) -> Self {
61 Self {
62 pointers: Arc::new(pointers),
63 _context: None,
64 }
65 }
66
67 pub fn get<T>(&self, name: &str) -> Option<*mut T> {
68 self.pointers.get(name).map(|p| p.as_ptr())
69 }
70
71 pub fn get_raw(&self, name: &str) -> Option<usize> {
72 self.pointers.get(name).map(|p| p.0)
73 }
74
75 pub unsafe fn read<T: Copy>(&self, name: &str) -> Option<T> {
78 self.get::<T>(name).map(|ptr| *ptr)
79 }
80
81 pub unsafe fn write<T: Copy>(&self, name: &str, value: T) -> bool {
84 if let Some(ptr) = self.get::<T>(name) {
85 *ptr = value;
86 true
87 } else {
88 false
89 }
90 }
91}
92
93pub async fn resolve_shm_pointers(
99 ctx_handle: SendShmContext,
100 names: Vec<String>,
101 timeout_ms: u64
102) -> Option<ShmMap> {
103 tokio::task::spawn_blocking(move || {
104 let ctx = ctx_handle;
105 let shm_ctx = &ctx.0;
106 let mut attempts = 0;
107 let max_attempts = timeout_ms / 100;
108
109 loop {
110 if let Ok(guard) = shm_ctx.try_lock() {
112 if let Some(ref ctx) = *guard {
113 info!("SHM context received, resolving {} pointers...", names.len());
114 let mut pointers = HashMap::new();
115 let mut mapped = 0;
116
117 for name in &names {
118 if let Some(ptr) = unsafe { ctx.get_pointer(name) } {
119 pointers.insert(name.clone(), ShmPtr::new(ptr));
120 mapped += 1;
121 }
122 }
123 info!("SHM pointers resolved: {}/{}", mapped, names.len());
124 return Some(ShmMap::new(pointers));
125 }
126 }
127
128 attempts += 1;
129 if attempts > max_attempts {
130 warn!("Timed out waiting for SHM context after {}ms", timeout_ms);
131 return None;
132 }
133
134 std::thread::sleep(Duration::from_millis(100));
135 }
136 }).await.unwrap_or(None)
137}