use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use log::{info, warn};
use shared_memory::Shmem;
use tokio::sync::Mutex;
use mechutil::shm::ShmContext;
/// Type-erased address of a location in shared memory.
///
/// The address is kept as a plain `usize` instead of a raw pointer
/// (presumably so the value is `Copy`/`Send`/`Sync` and can live in shared
/// maps). The pointee type is supplied again at each use via
/// [`ShmPtr::as_ptr`].
#[derive(Debug, Clone, Copy)]
pub struct ShmPtr(pub usize);

impl ShmPtr {
    /// Erases the type of `ptr`, keeping only its numeric address.
    pub fn new<T>(ptr: *mut T) -> Self {
        let address = ptr as usize;
        ShmPtr(address)
    }

    /// Reinterprets the stored address as a `*mut T`.
    ///
    /// No check is made that a `T` actually lives at this address; the
    /// returned pointer is only as valid as the caller's choice of `T`.
    pub fn as_ptr<T>(&self) -> *mut T {
        let ShmPtr(address) = *self;
        address as *mut T
    }
}
/// Wrapper that lets a [`Shmem`] mapping be moved to and shared between threads.
///
/// NOTE(review): `Shmem` is presumably not `Send`/`Sync` by itself (otherwise
/// this wrapper would be unnecessary). The `unsafe impl`s below assert thread
/// safety without proof — confirm that every access through this wrapper is
/// sound from multiple threads.
pub struct ThreadSafeShmem(pub Shmem);

// SAFETY: NOTE(review) — asserted by the original author, not demonstrated
// here; relies on the mapping being safe to hand to another thread.
unsafe impl Send for ThreadSafeShmem {}
// SAFETY: NOTE(review) — same unproven assertion, for shared `&` access
// from several threads at once.
unsafe impl Sync for ThreadSafeShmem {}

/// Debug output shows only the OS identifier of the mapping.
impl std::fmt::Debug for ThreadSafeShmem {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let os_id = self.0.get_os_id();
        let mut tuple = f.debug_tuple("ThreadSafeShmem");
        tuple.field(&os_id);
        tuple.finish()
    }
}

/// Exposes the wrapped mapping's methods directly on the wrapper.
impl std::ops::Deref for ThreadSafeShmem {
    type Target = Shmem;

    fn deref(&self) -> &Shmem {
        let ThreadSafeShmem(inner) = self;
        inner
    }
}
/// Cloneable, thread-shareable handle to an optional [`ShmContext`] slot.
///
/// Clones share the same slot through the outer `Arc`. `resolve_shm_pointers`
/// treats `None` as "not available yet" and polls until it becomes `Some`.
///
/// NOTE(review): the `unsafe impl`s below force `Send`/`Sync` regardless of
/// whether `ShmContext` provides them — confirm that `ShmContext` is safe to
/// use from the threads this handle reaches.
#[derive(Clone)]
pub struct SendShmContext(pub Arc<Mutex<Option<Arc<ShmContext>>>>);

// SAFETY: NOTE(review) — unproven assertion that shared access to the
// context from several threads is sound.
unsafe impl Sync for SendShmContext {}
// SAFETY: NOTE(review) — unproven assertion that the context may be moved
// to (and dropped on) another thread.
unsafe impl Send for SendShmContext {}
/// Read-only table mapping SHM variable names to their resolved addresses.
///
/// Cloning is cheap: the table is shared behind an `Arc`. The stored addresses
/// are type-erased ([`ShmPtr`]), so the pointee type is chosen by the caller at
/// each access and is never checked.
#[derive(Debug, Clone)]
pub struct ShmMap {
    pointers: Arc<HashMap<String, ShmPtr>>,
    // Optional handle to the backing mapping; `ShmMap::new` sets it to `None`.
    // NOTE(review): nothing in this view populates it, so the map does not by
    // itself keep the shared memory alive — confirm the owner outlives it.
    _context: Option<Arc<ThreadSafeShmem>>,
}

impl ShmMap {
    /// Wraps a set of already-resolved pointers; no backing mapping is attached.
    pub fn new(pointers: HashMap<String, ShmPtr>) -> Self {
        Self {
            pointers: Arc::new(pointers),
            _context: None,
        }
    }

    /// Returns the address stored under `name` cast to `*mut T`, or `None` if
    /// the name is unknown.
    ///
    /// Obtaining the pointer is safe; dereferencing it is not. The cast is
    /// unchecked, and the address is not guaranteed to be aligned for `T`.
    pub fn get<T>(&self, name: &str) -> Option<*mut T> {
        self.pointers.get(name).map(|p| p.as_ptr())
    }

    /// Returns the raw numeric address stored under `name`, or `None` if unknown.
    pub fn get_raw(&self, name: &str) -> Option<usize> {
        self.pointers.get(name).map(|p| p.0)
    }

    /// Copies the `T` stored under `name` out of shared memory.
    ///
    /// Returns `None` if `name` is not in the map. The address does not need
    /// to be aligned for `T`.
    ///
    /// # Safety
    ///
    /// If `name` is present, its address must point to at least
    /// `size_of::<T>()` bytes of live, readable memory that form a valid `T`.
    /// No other thread or process may write those bytes while the read is in
    /// progress, because that would be a data race.
    pub unsafe fn read<T: Copy>(&self, name: &str) -> Option<T> {
        self.get::<T>(name).map(|ptr| {
            // SAFETY: validity and absence of concurrent writers are the
            // caller's obligation (see `# Safety`). `read_unaligned` drops
            // the alignment requirement that a plain `*ptr` would impose on a
            // type-erased address.
            unsafe { ptr.read_unaligned() }
        })
    }

    /// Stores `value` at the address registered under `name`.
    ///
    /// Returns `true` if the name was found and the write performed, and
    /// `false` otherwise. The address does not need to be aligned for `T`.
    ///
    /// # Safety
    ///
    /// If `name` is present, its address must point to at least
    /// `size_of::<T>()` bytes of live, writable memory where storing a `T` is
    /// valid. No other thread or process may read or write those bytes while
    /// the write is in progress, because that would be a data race.
    pub unsafe fn write<T: Copy>(&self, name: &str, value: T) -> bool {
        match self.get::<T>(name) {
            Some(ptr) => {
                // SAFETY: see `# Safety`; `write_unaligned` tolerates any
                // alignment, and `T: Copy` means there is no old value to drop.
                unsafe { ptr.write_unaligned(value) };
                true
            }
            None => false,
        }
    }
}
/// Waits for the shared [`ShmContext`] to become available, then resolves
/// `names` into an [`ShmMap`].
///
/// The wait runs on tokio's blocking thread pool (`spawn_blocking`), so the
/// calling task is not stalled. The context slot is polled about every 100 ms.
/// If the slot is still `None` (or its lock is contended) once `timeout_ms`
/// milliseconds of wall-clock time have elapsed, a warning is logged and `None`
/// is returned. With `timeout_ms == 0`, exactly one attempt is made.
///
/// Names that the context cannot resolve are logged and left out of the map.
/// A map is still returned, so callers must handle `ShmMap::get` returning
/// `None`.
///
/// Returns `None` on timeout, or if the blocking task fails (e.g. panics).
pub async fn resolve_shm_pointers(
    ctx_handle: SendShmContext,
    names: Vec<String>,
    timeout_ms: u64,
) -> Option<ShmMap> {
    const POLL_INTERVAL: Duration = Duration::from_millis(100);

    let task = tokio::task::spawn_blocking(move || {
        // Bind the whole wrapper first. Under edition-2021 disjoint closure
        // captures, touching only `ctx_handle.0` would capture the inner field
        // directly. The closure's `Send` bound would then no longer go through
        // `SendShmContext`'s `unsafe impl Send`.
        let ctx = ctx_handle;

        // `None` when `timeout_ms` is too large to add to `Instant::now()`.
        // Treat that as "wait indefinitely" rather than panicking on overflow.
        let deadline = std::time::Instant::now().checked_add(Duration::from_millis(timeout_ms));

        loop {
            // Use `try_lock` so a contended mutex cannot stall this thread past
            // the deadline.
            if let Ok(guard) = ctx.0.try_lock() {
                if let Some(ref shm) = *guard {
                    info!("SHM context received, resolving {} pointers...", names.len());
                    let mut pointers = HashMap::with_capacity(names.len());
                    let mut missing: Vec<&str> = Vec::new();
                    for name in &names {
                        // SAFETY: NOTE(review) — the contract of
                        // `ShmContext::get_pointer` is defined in mechutil;
                        // confirm it is met here. This code only stores the
                        // returned address and never dereferences it.
                        match unsafe { shm.get_pointer(name) } {
                            Some(ptr) => {
                                pointers.insert(name.clone(), ShmPtr::new(ptr));
                            }
                            None => missing.push(name.as_str()),
                        }
                    }
                    let mapped = names.len() - missing.len();
                    info!("SHM pointers resolved: {}/{}", mapped, names.len());
                    if !missing.is_empty() {
                        warn!("SHM pointers not found: {}", missing.join(", "));
                    }
                    return Some(ShmMap::new(pointers));
                }
            }

            let remaining = match deadline {
                Some(deadline) => deadline.saturating_duration_since(std::time::Instant::now()),
                None => POLL_INTERVAL,
            };
            if remaining.is_zero() {
                warn!("Timed out waiting for SHM context after {}ms", timeout_ms);
                return None;
            }
            // Never oversleep the deadline; the next iteration makes a final attempt.
            std::thread::sleep(remaining.min(POLL_INTERVAL));
        }
    });

    match task.await {
        Ok(result) => result,
        Err(err) => {
            // Previously swallowed silently; a panic in the resolver is worth surfacing.
            warn!("SHM pointer resolution task failed: {}", err);
            None
        }
    }
}