Skip to main content

maolan_plugin_protocol/
shm.rs

1#[cfg(unix)]
2use std::ffi::CString;
3
/// Owned mapping of a shared-memory segment.
///
/// The view is unmapped when the value is dropped. On Unix, dropping does not
/// remove the named object itself; call [`ShmMapping::unlink`] for that.
pub struct ShmMapping {
    /// Start of the mapped view.
    ptr: *mut u8,
    /// Length of the mapped view in bytes.
    size: usize,
    /// Name the segment was created or opened with (read by `name()`).
    name: String,
    /// File-mapping handle backing the view; closed on drop.
    #[cfg(windows)]
    handle: *mut std::ffi::c_void,
}
13
14// Safety: ShmMapping is Send+Sync because the mapped memory is process-shared.
15unsafe impl Send for ShmMapping {}
16unsafe impl Sync for ShmMapping {}
17
18impl ShmMapping {
19    /// Create a new shared-memory segment, truncate to `size`, and map it.
20    #[cfg(unix)]
21    pub fn create(name: &str, size: usize) -> Result<Self, String> {
22        let c_name = CString::new(name).map_err(|e| e.to_string())?;
23        let fd = unsafe { libc::shm_open(c_name.as_ptr(), libc::O_CREAT | libc::O_RDWR, 0o644) };
24        if fd < 0 {
25            return Err(format!(
26                "shm_open({}, O_CREAT|O_RDWR) failed: {:?}",
27                name,
28                std::io::Error::last_os_error()
29            ));
30        }
31        if unsafe { libc::ftruncate(fd, size as libc::off_t) } < 0 {
32            unsafe { libc::close(fd) };
33            return Err(format!(
34                "ftruncate failed: {:?}",
35                std::io::Error::last_os_error()
36            ));
37        }
38        let ptr = unsafe {
39            libc::mmap(
40                std::ptr::null_mut(),
41                size,
42                libc::PROT_READ | libc::PROT_WRITE,
43                libc::MAP_SHARED,
44                fd,
45                0,
46            )
47        };
48        unsafe { libc::close(fd) };
49        if ptr == libc::MAP_FAILED {
50            unsafe { libc::shm_unlink(c_name.as_ptr()) };
51            return Err(format!(
52                "mmap failed: {:?}",
53                std::io::Error::last_os_error()
54            ));
55        }
56        Ok(Self {
57            ptr: ptr as *mut u8,
58            size,
59            name: name.to_string(),
60        })
61    }
62
63    /// Open an existing shared-memory segment and map it.
64    #[cfg(unix)]
65    pub fn open_existing(name: &str, size: usize) -> Result<Self, String> {
66        let c_name = CString::new(name).map_err(|e| e.to_string())?;
67        let fd = unsafe { libc::shm_open(c_name.as_ptr(), libc::O_RDWR, 0) };
68        if fd < 0 {
69            return Err(format!(
70                "shm_open({}, O_RDWR) failed: {:?}",
71                name,
72                std::io::Error::last_os_error()
73            ));
74        }
75        let ptr = unsafe {
76            libc::mmap(
77                std::ptr::null_mut(),
78                size,
79                libc::PROT_READ | libc::PROT_WRITE,
80                libc::MAP_SHARED,
81                fd,
82                0,
83            )
84        };
85        unsafe { libc::close(fd) };
86        if ptr == libc::MAP_FAILED {
87            return Err(format!(
88                "mmap failed: {:?}",
89                std::io::Error::last_os_error()
90            ));
91        }
92        Ok(Self {
93            ptr: ptr as *mut u8,
94            size,
95            name: name.to_string(),
96        })
97    }
98
99    /// Create a new pagefile-backed shared-memory segment on Windows.
100    #[cfg(windows)]
101    pub fn create(name: &str, size: usize) -> Result<Self, String> {
102        use windows_sys::Win32::Foundation::{GetLastError, INVALID_HANDLE_VALUE};
103        use windows_sys::Win32::System::Memory::{
104            CreateFileMappingW, FILE_MAP_ALL_ACCESS, MapViewOfFile, PAGE_READWRITE,
105        };
106
107        let wide_name: Vec<u16> = format!("Local\\{}", name)
108            .encode_utf16()
109            .chain(std::iter::once(0))
110            .collect();
111        let handle = unsafe {
112            CreateFileMappingW(
113                INVALID_HANDLE_VALUE,
114                std::ptr::null_mut(),
115                PAGE_READWRITE,
116                0,
117                size as u32,
118                wide_name.as_ptr(),
119            )
120        };
121        if handle.is_null() {
122            return Err(format!("CreateFileMappingW failed: {}", unsafe {
123                GetLastError()
124            }));
125        }
126        let ptr = unsafe { MapViewOfFile(handle, FILE_MAP_ALL_ACCESS, 0, 0, size) };
127        if ptr.Value.is_null() {
128            unsafe { windows_sys::Win32::Foundation::CloseHandle(handle) };
129            return Err(format!("MapViewOfFile failed: {}", unsafe {
130                GetLastError()
131            }));
132        }
133        Ok(Self {
134            ptr: ptr.Value as *mut u8,
135            size,
136            name: name.to_string(),
137            handle,
138        })
139    }
140
141    /// Open an existing shared-memory segment on Windows.
142    #[cfg(windows)]
143    pub fn open_existing(name: &str, size: usize) -> Result<Self, String> {
144        use windows_sys::Win32::Foundation::{CloseHandle, GetLastError};
145        use windows_sys::Win32::System::Memory::{
146            FILE_MAP_ALL_ACCESS, MapViewOfFile, OpenFileMappingW,
147        };
148
149        let wide_name: Vec<u16> = format!("Local\\{}", name)
150            .encode_utf16()
151            .chain(std::iter::once(0))
152            .collect();
153        let handle = unsafe { OpenFileMappingW(FILE_MAP_ALL_ACCESS, 0, wide_name.as_ptr()) };
154        if handle.is_null() {
155            return Err(format!("OpenFileMappingW failed: {}", unsafe {
156                GetLastError()
157            }));
158        }
159        let ptr = unsafe { MapViewOfFile(handle, FILE_MAP_ALL_ACCESS, 0, 0, size) };
160        if ptr.Value.is_null() {
161            unsafe { CloseHandle(handle) };
162            return Err(format!("MapViewOfFile failed: {}", unsafe {
163                GetLastError()
164            }));
165        }
166        Ok(Self {
167            ptr: ptr.Value as *mut u8,
168            size,
169            name: name.to_string(),
170            handle,
171        })
172    }
173
174    /// Raw pointer to the start of the mapping.
175    pub fn as_ptr(&self) -> *mut u8 {
176        self.ptr
177    }
178
179    /// Size of the mapping in bytes.
180    pub fn size(&self) -> usize {
181        self.size
182    }
183
184    /// Name used to create/open the segment.
185    pub fn name(&self) -> &str {
186        &self.name
187    }
188
189    /// Unlink the underlying POSIX shared-memory object.
190    #[cfg(unix)]
191    pub fn unlink(name: &str) -> Result<(), String> {
192        let c_name = CString::new(name).map_err(|e| e.to_string())?;
193        let res = unsafe { libc::shm_unlink(c_name.as_ptr()) };
194        if res < 0 {
195            Err(format!(
196                "shm_unlink failed: {:?}",
197                std::io::Error::last_os_error()
198            ))
199        } else {
200            Ok(())
201        }
202    }
203
204    #[cfg(windows)]
205    pub fn unlink(_name: &str) -> Result<(), String> {
206        Ok(())
207    }
208}
209
210#[cfg(unix)]
211impl Drop for ShmMapping {
212    fn drop(&mut self) {
213        if !self.ptr.is_null() && self.ptr != libc::MAP_FAILED as *mut u8 {
214            unsafe { libc::munmap(self.ptr as *mut libc::c_void, self.size) };
215        }
216    }
217}
218
#[cfg(windows)]
impl Drop for ShmMapping {
    /// Unmap the view, then close the file-mapping handle.
    fn drop(&mut self) {
        use windows_sys::Win32::Foundation::CloseHandle;
        use windows_sys::Win32::System::Memory::{MEMORY_MAPPED_VIEW_ADDRESS, UnmapViewOfFile};

        let view = self.ptr;
        if !view.is_null() {
            let address = MEMORY_MAPPED_VIEW_ADDRESS {
                Value: view.cast::<std::ffi::c_void>(),
            };
            // SAFETY: `address` is the base of the view mapped by the constructor.
            unsafe { UnmapViewOfFile(address) };
        }

        let handle = self.handle;
        if handle.is_null() {
            return;
        }
        // SAFETY: `handle` was returned by CreateFileMappingW/OpenFileMappingW
        // and is closed only here.
        unsafe { CloseHandle(handle) };
    }
}
237}