use std::fs::File;
use std::io;
use std::os::fd::AsFd;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use memmap2::MmapMut;

use wayrs_client::global::BindError;
use wayrs_client::object::Proxy;
use wayrs_client::protocol::*;
use wayrs_client::Connection;

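/// A shared memory allocator for Wayland clients, backed by a single
/// `wl_shm_pool` that is created lazily on the first allocation.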
#[derive(Debug)]
pub struct ShmAlloc {
    state: ShmAllocState,
}

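/// `Uninit` holds the bound `wl_shm` global until the first allocation
/// creates the pool.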
#[derive(Debug)]
enum ShmAllocState {
    Uninit(WlShm),
    Init(InitShmPool),
}

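/// An initialized pool: the backing file, its memory mapping, and a list of
/// segments that partition the pool by offset.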
#[derive(Debug)]
struct InitShmPool {
    pool: WlShmPool,
    len: usize,
    file: File,
    mmap: MmapMut,
    segments: Vec<Segment>,
}

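/// A contiguous region of the pool. `refcnt` counts live `Buffer` handles and
/// not-yet-released `wl_buffer`s; the segment is free once it reaches zero.
/// `buffer` caches the segment's `wl_buffer` so it can be reused later.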
#[derive(Debug)]
struct Segment {
    offset: usize,
    len: usize,
    refcnt: Arc<AtomicU32>,
    buffer: Option<(WlBuffer, BufferSpec)>,
}

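/// The dimensions, stride, and pixel format of a buffer.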
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BufferSpec {
    pub width: u32,
    pub height: u32,
    pub stride: u32,
    pub format: wl_shm::Format,
}

impl BufferSpec {
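    /// The size of a buffer with this spec, in bytes.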
    pub fn size(&self) -> usize {
        self.stride as usize * self.height as usize
    }
}

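/// A buffer allocated from a [`ShmAlloc`] pool.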
#[derive(Debug)]
pub struct Buffer {
    spec: BufferSpec,
    wl: WlBuffer,
    refcnt: Arc<AtomicU32>,
    wl_shm_pool: WlShmPool,
    offset: usize,
}

impl ShmAlloc {
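    /// Bind the `wl_shm` global (version 1 or 2) and create an allocator.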
    pub fn bind<D>(conn: &mut Connection<D>) -> Result<Self, BindError> {
        Ok(Self::new(conn.bind_singleton(1..=2)?))
    }

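    /// Create an allocator from an already bound `wl_shm` global.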
    pub fn new(wl_shm: WlShm) -> Self {
        Self {
            state: ShmAllocState::Uninit(wl_shm),
        }
    }

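    /// Allocate a buffer with the given spec, returning the handle together
    /// with the mapped bytes of its segment.
    ///
    /// The `wl_shm_pool` is created on the first call.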
    pub fn alloc_buffer<D>(
        &mut self,
        conn: &mut Connection<D>,
        spec: BufferSpec,
    ) -> io::Result<(Buffer, &mut [u8])> {
        if matches!(&self.state, ShmAllocState::Init(_)) {
            let ShmAllocState::Init(pool) = &mut self.state else {
                unreachable!()
            };
            return pool.alloc_buffer(conn, spec);
        }

        let &ShmAllocState::Uninit(wl_shm) = &self.state else {
            unreachable!()
        };

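        // First allocation: create the pool, then release the `wl_shm` global
        // if the compositor supports it (the request exists since version 2).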
        self.state = ShmAllocState::Init(InitShmPool::new(conn, wl_shm, spec.size())?);
        if wl_shm.version() >= 2 {
            wl_shm.release(conn);
        }
        let ShmAllocState::Init(pool) = &mut self.state else {
            unreachable!()
        };
        pool.alloc_buffer(conn, spec)
    }

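    /// Destroy the allocator, releasing its protocol objects. Buffers created
    /// from the pool remain valid: the compositor keeps the pool's memory
    /// alive until all of its buffers are destroyed.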
    pub fn destroy<D>(self, conn: &mut Connection<D>) {
        match self.state {
            ShmAllocState::Uninit(wl_shm) => {
                if wl_shm.version() >= 2 {
                    wl_shm.release(conn);
                }
            }
            ShmAllocState::Init(pool) => {
                pool.pool.destroy(conn);
            }
        }
    }
}

impl Buffer {
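    /// Convert into a raw `wl_buffer`, e.g. to attach it to a surface. The
    /// segment is reclaimed when the compositor sends `wl_buffer.release`.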
    #[must_use = "memory is leaked if wl_buffer is not attached"]
    pub fn into_wl_buffer(self) -> WlBuffer {
        let wl = self.wl;
        std::mem::forget(self);
        wl
    }

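    /// Create another `wl_buffer` backed by this buffer's pool segment. It
    /// destroys itself and drops its segment reference once released by the
    /// compositor.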
    #[must_use = "memory is leaked if wl_buffer is not attached"]
    pub fn duplicate<D>(&self, conn: &mut Connection<D>) -> WlBuffer {
        self.refcnt.fetch_add(1, Ordering::AcqRel);
        let refcnt = Arc::clone(&self.refcnt);
        self.wl_shm_pool.create_buffer_with_cb(
            conn,
            self.offset as i32,
            self.spec.width as i32,
            self.spec.height as i32,
            self.spec.stride as i32,
            self.spec.format,
            move |ctx| {
                assert!(refcnt.fetch_sub(1, Ordering::AcqRel) > 0);
                ctx.proxy.destroy(ctx.conn);
            },
        )
    }

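    /// The spec this buffer was allocated with.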
    pub fn spec(&self) -> BufferSpec {
        self.spec
    }
}

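// Dropping a `Buffer` without converting it into a raw `wl_buffer` releases
// its reference to the segment, making the memory reusable.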
impl Drop for Buffer {
    fn drop(&mut self) {
        assert!(self.refcnt.fetch_sub(1, Ordering::AcqRel) > 0);
    }
}

impl InitShmPool {
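    /// Create a shared memory file of `size` bytes, map it, and hand a
    /// duplicated fd to the compositor via `wl_shm.create_pool`. The whole
    /// pool starts out as a single free segment.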
    fn new<D>(conn: &mut Connection<D>, wl_shm: WlShm, size: usize) -> io::Result<InitShmPool> {
        let file = shmemfdrs2::create_shmem(c"/wayrs_shm_pool")?;
        file.set_len(size as u64)?;
        let mmap = unsafe { MmapMut::map_mut(&file)? };

        let fd_dup = file.as_fd().try_clone_to_owned()?;

        let pool = wl_shm.create_pool(conn, fd_dup, size as i32);

        Ok(Self {
            pool,
            len: size,
            file,
            mmap,
            segments: vec![Segment {
                offset: 0,
                len: size,
                refcnt: Arc::new(AtomicU32::new(0)),
                buffer: None,
            }],
        })
    }

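    /// Allocate a segment for `spec` and return a `Buffer` plus the mapped
    /// bytes. The segment's cached `wl_buffer` is reused if present;
    /// otherwise a new one is created with a callback that drops the segment
    /// reference once the compositor releases the buffer.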
    fn alloc_buffer<D>(
        &mut self,
        conn: &mut Connection<D>,
        spec: BufferSpec,
    ) -> io::Result<(Buffer, &mut [u8])> {
        let segment_index = self.alloc_segment(conn, spec)?;
        let segment = &mut self.segments[segment_index];

        let (wl, spec) = *segment.buffer.get_or_insert_with(|| {
            let seg_refcnt = Arc::clone(&segment.refcnt);
            let wl = self.pool.create_buffer_with_cb(
                conn,
                segment.offset as i32,
                spec.width as i32,
                spec.height as i32,
                spec.stride as i32,
                spec.format,
                move |_| {
                    assert!(seg_refcnt.fetch_sub(1, Ordering::SeqCst) > 0);
                },
            );
            (wl, spec)
        });

        Ok((
            Buffer {
                spec,
                wl,
                refcnt: Arc::clone(&segment.refcnt),
                wl_shm_pool: self.pool,
                offset: segment.offset,
            },
            &mut self.mmap[segment.offset..][..segment.len],
        ))
    }

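    /// Merge adjacent free segments, destroying their cached buffers.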
    fn defragment<D>(&mut self, conn: &mut Connection<D>) {
        let mut i = 0;
        while i + 1 < self.segments.len() {
            if self.segments[i].refcnt.load(Ordering::SeqCst) != 0
                || self.segments[i + 1].refcnt.load(Ordering::SeqCst) != 0
            {
                i += 1;
                continue;
            }

            if let Some(buffer) = self.segments[i].buffer.take() {
                buffer.0.destroy(conn);
            }
            if let Some(buffer) = self.segments[i + 1].buffer.take() {
                buffer.0.destroy(conn);
            }

            self.segments[i].len += self.segments[i + 1].len;

            self.segments.remove(i + 1);
        }
    }

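    /// Grow the file, the mapping, and the `wl_shm_pool` to hold at least
    /// `new_len` bytes, at least doubling the current size each time. The
    /// pool never shrinks.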
    fn resize<D>(&mut self, conn: &mut Connection<D>, new_len: usize) -> io::Result<()> {
        if new_len > self.len {
            self.len = usize::max(self.len * 2, new_len);
            self.file.set_len(self.len as u64)?;
            self.pool.resize(conn, self.len as i32);
            self.mmap = unsafe { MmapMut::map_mut(&self.file)? };
        }
        Ok(())
    }

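    /// Try to claim a free segment of `len` bytes without growing the pool.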
    fn try_alloc_in_place<D>(
        &mut self,
        conn: &mut Connection<D>,
        len: usize,
        spec: BufferSpec,
    ) -> Option<usize> {
        fn take_if_free(s: &Segment) -> bool {
            s.refcnt
                .compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
        }

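        // First pass: claim a free segment of exactly the right size, and
        // drop its cached buffer if the spec does not match.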
        if let Some((i, segment)) = self
            .segments
            .iter_mut()
            .enumerate()
            .filter(|(_, s)| s.len == len)
            .find(|(_, s)| take_if_free(s))
        {
            if let Some(buffer) = &segment.buffer {
                if buffer.1 != spec {
                    buffer.0.destroy(conn);
                    segment.buffer = None;
                }
            }
            return Some(i);
        }

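        // Second pass: claim a larger free segment and split off the excess
        // as a new free segment.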
        if let Some((i, segment)) = self
            .segments
            .iter_mut()
            .enumerate()
            .filter(|(_, s)| s.len > len)
            .find(|(_, s)| take_if_free(s))
        {
            if let Some(buffer) = segment.buffer.take() {
                buffer.0.destroy(conn);
            }
            let extra = segment.len - len;
            let offset = segment.offset + len;
            segment.len = len;
            self.segments.insert(
                i + 1,
                Segment {
                    offset,
                    len: extra,
                    refcnt: Arc::new(AtomicU32::new(0)),
                    buffer: None,
                },
            );
            return Some(i);
        }

        None
    }

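    /// Find or create a segment of `spec.size()` bytes and return its index,
    /// defragmenting and growing the pool as needed.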
    fn alloc_segment<D>(
        &mut self,
        conn: &mut Connection<D>,
        spec: BufferSpec,
    ) -> io::Result<usize> {
        let len = spec.size();

        if let Some(index) = self.try_alloc_in_place(conn, len, spec) {
            return Ok(index);
        }

        self.defragment(conn);
        if let Some(index) = self.try_alloc_in_place(conn, len, spec) {
            return Ok(index);
        }

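        // Grow the pool: extend the last segment if it is free, otherwise
        // append a new one. `segments_len` is the end offset of the
        // allocated segment.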
        let segments_len = match self.segments.last_mut() {
            Some(segment)
                if segment
                    .refcnt
                    .compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire)
                    .is_ok() =>
            {
                if let Some(buffer) = segment.buffer.take() {
                    buffer.0.destroy(conn);
                }
                segment.len = len;
                let new_size = segment.offset + segment.len;
                self.resize(conn, new_size)?;
                new_size
            }
            _ => {
                let offset = self.len;
                self.resize(conn, self.len + len)?;
                self.segments.push(Segment {
                    offset,
                    len,
                    refcnt: Arc::new(AtomicU32::new(1)),
                    buffer: None,
                });
                offset + len
            }
        };

        let index = self.segments.len() - 1;

        // The pool may have grown past the end of the allocated segment (its
        // size is at least doubled on growth); cover the leftover space with
        // a free segment so that the segments keep tiling the whole pool.
        if self.len > segments_len {
            self.segments.push(Segment {
                offset: segments_len,
                len: self.len - segments_len,
                refcnt: Arc::new(AtomicU32::new(0)),
                buffer: None,
            });
        }

        Ok(index)
    }
}