Skip to main content

maolan_plugin_protocol/
shm.rs

1#[cfg(unix)]
2use std::ffi::CString;
3
/// A read/write, shared mapping of a named shared-memory object.
///
/// Built with [`ShmMapping::create`] or [`ShmMapping::open_existing`]. Dropping
/// the value unmaps the memory (and, on Windows, closes the mapping handle). On
/// Unix the shared-memory name stays in place until [`ShmMapping::unlink`] is
/// called.
pub struct ShmMapping {
    /// Base address of the mapped region.
    ptr: *mut u8,
    /// Length of the mapped region in bytes, as requested at construction.
    size: usize,
    /// Name the mapping was created/opened with (without the Windows `Local\`
    /// prefix); exposed through [`ShmMapping::name`].
    name: String,
    /// Handle of the Windows file-mapping object backing the view.
    #[cfg(windows)]
    handle: *mut std::ffi::c_void,
}
12
13unsafe impl Send for ShmMapping {}
14unsafe impl Sync for ShmMapping {}
15
16impl ShmMapping {
17    #[cfg(unix)]
18    pub fn create(name: &str, size: usize) -> Result<Self, String> {
19        let c_name = CString::new(name).map_err(|e| e.to_string())?;
20        let fd = unsafe { libc::shm_open(c_name.as_ptr(), libc::O_CREAT | libc::O_RDWR, 0o644) };
21        if fd < 0 {
22            return Err(format!(
23                "shm_open({}, O_CREAT|O_RDWR) failed: {:?}",
24                name,
25                std::io::Error::last_os_error()
26            ));
27        }
28        if unsafe { libc::ftruncate(fd, size as libc::off_t) } < 0 {
29            unsafe { libc::close(fd) };
30            return Err(format!(
31                "ftruncate failed: {:?}",
32                std::io::Error::last_os_error()
33            ));
34        }
35        let ptr = unsafe {
36            libc::mmap(
37                std::ptr::null_mut(),
38                size,
39                libc::PROT_READ | libc::PROT_WRITE,
40                libc::MAP_SHARED,
41                fd,
42                0,
43            )
44        };
45        unsafe { libc::close(fd) };
46        if ptr == libc::MAP_FAILED {
47            unsafe { libc::shm_unlink(c_name.as_ptr()) };
48            return Err(format!(
49                "mmap failed: {:?}",
50                std::io::Error::last_os_error()
51            ));
52        }
53        Ok(Self {
54            ptr: ptr as *mut u8,
55            size,
56            name: name.to_string(),
57        })
58    }
59
60    #[cfg(unix)]
61    pub fn open_existing(name: &str, size: usize) -> Result<Self, String> {
62        let c_name = CString::new(name).map_err(|e| e.to_string())?;
63        let fd = unsafe { libc::shm_open(c_name.as_ptr(), libc::O_RDWR, 0) };
64        if fd < 0 {
65            return Err(format!(
66                "shm_open({}, O_RDWR) failed: {:?}",
67                name,
68                std::io::Error::last_os_error()
69            ));
70        }
71        let ptr = unsafe {
72            libc::mmap(
73                std::ptr::null_mut(),
74                size,
75                libc::PROT_READ | libc::PROT_WRITE,
76                libc::MAP_SHARED,
77                fd,
78                0,
79            )
80        };
81        unsafe { libc::close(fd) };
82        if ptr == libc::MAP_FAILED {
83            return Err(format!(
84                "mmap failed: {:?}",
85                std::io::Error::last_os_error()
86            ));
87        }
88        Ok(Self {
89            ptr: ptr as *mut u8,
90            size,
91            name: name.to_string(),
92        })
93    }
94
95    #[cfg(windows)]
96    pub fn create(name: &str, size: usize) -> Result<Self, String> {
97        use windows_sys::Win32::Foundation::{GetLastError, INVALID_HANDLE_VALUE};
98        use windows_sys::Win32::System::Memory::{
99            CreateFileMappingW, FILE_MAP_ALL_ACCESS, MapViewOfFile, PAGE_READWRITE,
100        };
101
102        let wide_name: Vec<u16> = format!("Local\\{}", name)
103            .encode_utf16()
104            .chain(std::iter::once(0))
105            .collect();
106        let handle = unsafe {
107            CreateFileMappingW(
108                INVALID_HANDLE_VALUE,
109                std::ptr::null_mut(),
110                PAGE_READWRITE,
111                0,
112                size as u32,
113                wide_name.as_ptr(),
114            )
115        };
116        if handle.is_null() {
117            return Err(format!("CreateFileMappingW failed: {}", unsafe {
118                GetLastError()
119            }));
120        }
121        let ptr = unsafe { MapViewOfFile(handle, FILE_MAP_ALL_ACCESS, 0, 0, size) };
122        if ptr.Value.is_null() {
123            unsafe { windows_sys::Win32::Foundation::CloseHandle(handle) };
124            return Err(format!("MapViewOfFile failed: {}", unsafe {
125                GetLastError()
126            }));
127        }
128        Ok(Self {
129            ptr: ptr.Value as *mut u8,
130            size,
131            name: name.to_string(),
132            handle,
133        })
134    }
135
136    #[cfg(windows)]
137    pub fn open_existing(name: &str, size: usize) -> Result<Self, String> {
138        use windows_sys::Win32::Foundation::{CloseHandle, GetLastError};
139        use windows_sys::Win32::System::Memory::{
140            FILE_MAP_ALL_ACCESS, MapViewOfFile, OpenFileMappingW,
141        };
142
143        let wide_name: Vec<u16> = format!("Local\\{}", name)
144            .encode_utf16()
145            .chain(std::iter::once(0))
146            .collect();
147        let handle = unsafe { OpenFileMappingW(FILE_MAP_ALL_ACCESS, 0, wide_name.as_ptr()) };
148        if handle.is_null() {
149            return Err(format!("OpenFileMappingW failed: {}", unsafe {
150                GetLastError()
151            }));
152        }
153        let ptr = unsafe { MapViewOfFile(handle, FILE_MAP_ALL_ACCESS, 0, 0, size) };
154        if ptr.Value.is_null() {
155            unsafe { CloseHandle(handle) };
156            return Err(format!("MapViewOfFile failed: {}", unsafe {
157                GetLastError()
158            }));
159        }
160        Ok(Self {
161            ptr: ptr.Value as *mut u8,
162            size,
163            name: name.to_string(),
164            handle,
165        })
166    }
167
168    pub fn as_ptr(&self) -> *mut u8 {
169        self.ptr
170    }
171
172    pub fn size(&self) -> usize {
173        self.size
174    }
175
176    pub fn name(&self) -> &str {
177        &self.name
178    }
179
180    #[cfg(unix)]
181    pub fn unlink(name: &str) -> Result<(), String> {
182        let c_name = CString::new(name).map_err(|e| e.to_string())?;
183        let res = unsafe { libc::shm_unlink(c_name.as_ptr()) };
184        if res < 0 {
185            Err(format!(
186                "shm_unlink failed: {:?}",
187                std::io::Error::last_os_error()
188            ))
189        } else {
190            Ok(())
191        }
192    }
193
194    #[cfg(windows)]
195    pub fn unlink(_name: &str) -> Result<(), String> {
196        Ok(())
197    }
198}
199
200#[cfg(unix)]
201impl Drop for ShmMapping {
202    fn drop(&mut self) {
203        if !self.ptr.is_null() && self.ptr != libc::MAP_FAILED as *mut u8 {
204            unsafe { libc::munmap(self.ptr as *mut libc::c_void, self.size) };
205        }
206    }
207}
208
#[cfg(windows)]
impl Drop for ShmMapping {
    /// Releases the mapped view first, then the file-mapping handle. Each step is
    /// skipped when the corresponding pointer/handle is null.
    fn drop(&mut self) {
        use windows_sys::Win32::Foundation::CloseHandle;
        use windows_sys::Win32::System::Memory::{MEMORY_MAPPED_VIEW_ADDRESS, UnmapViewOfFile};

        let view_base = self.ptr.cast::<std::ffi::c_void>();
        if !view_base.is_null() {
            // SAFETY: `view_base` is the address MapViewOfFile returned for this
            // mapping, and it is unmapped only here.
            unsafe { UnmapViewOfFile(MEMORY_MAPPED_VIEW_ADDRESS { Value: view_base }) };
        }

        let section = self.handle;
        if !section.is_null() {
            // SAFETY: the handle was obtained by the constructor and is closed
            // exactly once, here.
            unsafe { CloseHandle(section) };
        }
    }
}
227}