1use {
4 crate::{
5 acceptor::{Acceptor, AcceptorError},
6 baseline::Baseline,
7 client::Client,
8 endpoint::{Endpoint, EndpointError},
9 handler::HandlerHolder,
10 object::{Object, ObjectCoreApi, ObjectErrorKind, ObjectPrivate},
11 poll::{self, PollError, PollEvent, Poller},
12 protocols::wayland::wl_display::WlDisplay,
13 trans::{FlushResult, TransError},
14 utils::{
15 env::{WAYLAND_DISPLAY, WAYLAND_SOCKET, XDG_RUNTIME_DIR},
16 stack::Stack,
17 stash::Stash,
18 },
19 },
20 error_reporter::Report,
21 run_on_drop::on_drop,
22 std::{
23 cell::{Cell, RefCell},
24 collections::HashMap,
25 io::{self, pipe},
26 os::fd::{AsFd, OwnedFd},
27 rc::{Rc, Weak},
28 sync::{
29 Arc,
30 atomic::{AtomicBool, Ordering::Acquire},
31 },
32 time::Duration,
33 },
34 thiserror::Error,
35 uapi::c,
36};
37pub use {
38 builder::StateBuilder,
39 destructor::{Destructor, RemoteDestructor},
40};
41
42mod builder;
43mod destructor;
44#[cfg(test)]
45mod tests;
46
47#[derive(Debug, Error)]
49#[error(transparent)]
50pub struct StateError(#[from] StateErrorKind);
51
52#[derive(Debug, Error)]
53enum StateErrorKind {
54 #[error("the state has already been destroyed")]
55 Destroyed,
56 #[error("the state has been destroyed by a remote destructor")]
57 RemoteDestroyed,
58 #[error("cannot perform recursive call into the state")]
59 RecursiveCall,
60 #[error("the server hung up the connection")]
61 ServerHangup,
62 #[error("could not write to the server socket")]
63 WriteToServer(#[source] EndpointError),
64 #[error("could not dispatch server events")]
65 DispatchEvents(#[source] EndpointError),
66 #[error("could not create a socket pair")]
67 Socketpair(#[source] io::Error),
68 #[error(transparent)]
69 CreateAcceptor(AcceptorError),
70 #[error("could not accept a new connection")]
71 AcceptConnection(AcceptorError),
72 #[error("could not create a pipe")]
73 CreatePipe(#[source] io::Error),
74 #[error("could not read {} environment variable", WAYLAND_DISPLAY)]
75 WaylandDisplay,
76 #[error("the display name is empty")]
77 WaylandDisplayEmpty,
78 #[error("{} is not set", XDG_RUNTIME_DIR)]
79 XrdNotSet,
80 #[error("the socket path is too long")]
81 SocketPathTooLong,
82 #[error("could not create a socket")]
83 CreateSocket(#[source] io::Error),
84 #[error("could not connect to {0}")]
85 Connect(String, #[source] io::Error),
86 #[error("{} does not contain a valid number", WAYLAND_SOCKET)]
87 WaylandSocketNotNumber,
88 #[error("F_GETFD failed on {}", WAYLAND_SOCKET)]
89 WaylandSocketGetFd(#[source] io::Error),
90 #[error("F_SETFD failed on {}", WAYLAND_SOCKET)]
91 WaylandSocketSetFd(#[source] io::Error),
92 #[error(transparent)]
93 PollError(PollError),
94}
95
96pub struct State {
149 pub(crate) baseline: Baseline,
150 poller: Poller,
151 next_pollable_id: Cell<u64>,
152 pub(crate) server: Option<Rc<Endpoint>>,
153 pub(crate) destroyed: Cell<bool>,
154 handler: HandlerHolder<dyn StateHandler>,
155 pollables: RefCell<HashMap<u64, Pollable>>,
156 acceptable_acceptors: Stack<Rc<Acceptor>>,
157 has_acceptable_acceptors: Cell<bool>,
158 clients_to_kill: Stack<Rc<Client>>,
159 has_clients_to_kill: Cell<bool>,
160 readable_endpoints: Stack<EndpointWithClient>,
161 has_readable_endpoints: Cell<bool>,
162 flushable_endpoints: Stack<EndpointWithClient>,
163 has_flushable_endpoints: Cell<bool>,
164 interest_update_endpoints: Stack<Rc<Endpoint>>,
165 has_interest_update_endpoints: Cell<bool>,
166 interest_update_acceptors: Stack<Rc<Acceptor>>,
167 has_interest_update_acceptors: Cell<bool>,
168 pub(crate) all_objects: RefCell<HashMap<u64, Weak<dyn Object>>>,
169 pub(crate) next_object_id: Cell<u64>,
170 #[cfg(feature = "logging")]
171 pub(crate) log: bool,
172 #[cfg(feature = "logging")]
173 pub(crate) log_prefix: String,
174 #[cfg(feature = "logging")]
175 log_writer: RefCell<io::BufWriter<uapi::Fd>>,
176 global_lock_held: Cell<bool>,
177 pub(crate) object_stash: Stash<Rc<dyn Object>>,
178 pub(crate) forward_to_client: Cell<bool>,
179 pub(crate) forward_to_server: Cell<bool>,
180}
181
182pub trait StateHandler: 'static {
184 fn new_client(&mut self, client: &Rc<Client>) {
189 let _ = client;
190 }
191
192 fn display_error(
199 self: Box<Self>,
200 object: Option<&Rc<dyn Object>>,
201 server_id: u32,
202 error: u32,
203 msg: &str,
204 ) {
205 let _ = object;
206 let _ = server_id;
207 let _ = error;
208 let _ = msg;
209 }
210}
211
212enum Pollable {
213 Endpoint(EndpointWithClient),
214 Acceptor(Rc<Acceptor>),
215 Destructor(OwnedFd, Arc<AtomicBool>),
216}
217
218#[derive(Clone)]
219struct EndpointWithClient {
220 endpoint: Rc<Endpoint>,
221 client: Option<Rc<Client>>,
222}
223
224pub(crate) struct HandlerLock<'a> {
225 state: &'a State,
226}
227
228impl State {
229 pub(crate) fn remove_endpoint(&self, endpoint: &Endpoint) {
230 self.pollables.borrow_mut().remove(&endpoint.id);
231 self.poller.unregister(endpoint.socket.as_fd());
232 endpoint.unregistered.set(true);
233 }
234
235 fn acquire_handler_lock(&self) -> Result<HandlerLock<'_>, StateErrorKind> {
236 if self.global_lock_held.replace(true) {
237 return Err(StateErrorKind::RecursiveCall);
238 }
239 Ok(HandlerLock { state: self })
240 }
241
242 fn flush_locked(&self, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
243 let mut did_work = false;
244 did_work |= self.perform_writes(lock)?;
245 did_work |= self.kill_clients();
246 self.update_interests()?;
247 Ok(did_work)
248 }
249
250 pub(crate) fn handle_delete_id(&self, server: &Endpoint, id: u32) {
251 let object = server.objects.borrow_mut().remove(&id).unwrap();
252 let core = object.core();
253 core.server_obj_id.take();
254 server.idl.release(id);
255 if let Err((e, object)) = object.delete_id() {
256 log::warn!(
257 "Could not handle a wl_display.delete_id message: {}",
258 Report::new(e),
259 );
260 let _ = object.core().try_delete_id();
261 }
262 }
263
264 fn perform_writes(&self, _: &HandlerLock<'_>) -> Result<bool, StateError> {
265 if !self.has_flushable_endpoints.get() {
266 return Ok(false);
267 }
268 while let Some(ewc) = self.flushable_endpoints.pop() {
269 let res = match ewc.endpoint.flush() {
270 Ok(r) => r,
271 Err(e) => {
272 let is_closed = matches!(e, EndpointError::Flush(TransError::Closed));
273 if let Some(client) = &ewc.client {
274 if !is_closed {
275 log::warn!(
276 "Could not write to client#{}: {}",
277 client.endpoint.id,
278 Report::new(e),
279 );
280 }
281 self.add_client_to_kill(client);
282 } else {
283 if is_closed {
284 return Err(StateErrorKind::ServerHangup.into());
285 }
286 return Err(StateErrorKind::WriteToServer(e).into());
287 }
288 continue;
289 }
290 };
291 match res {
292 FlushResult::Done => {
293 ewc.endpoint.flush_queued.set(false);
294 self.change_interest(&ewc.endpoint, |i| i & !poll::WRITABLE);
295 }
296 FlushResult::Blocked => {
297 self.change_interest(&ewc.endpoint, |i| i | poll::WRITABLE);
298 }
299 }
300 }
301 self.has_flushable_endpoints.set(false);
302 Ok(true)
303 }
304
305 fn accept_connections(self: &Rc<Self>, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
306 if !self.has_acceptable_acceptors.get() {
307 return Ok(false);
308 }
309 self.check_destroyed()?;
310 while let Some(acceptor) = self.acceptable_acceptors.pop() {
311 self.interest_update_acceptors.push(acceptor.clone());
312 self.has_interest_update_acceptors.set(true);
313 const MAX_ACCEPT_PER_ITERATION: usize = 10;
314 for _ in 0..MAX_ACCEPT_PER_ITERATION {
315 let socket = acceptor
316 .accept()
317 .map_err(StateErrorKind::AcceptConnection)?;
318 let Some(socket) = socket else {
319 break;
320 };
321 self.create_client(Some(lock), &Rc::new(socket))?;
322 }
323 }
324 self.has_acceptable_acceptors.set(false);
325 Ok(true)
326 }
327
328 fn read_messages(&self, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
329 if !self.has_readable_endpoints.get() {
330 return Ok(false);
331 }
332 while let Some(ewc) = self.readable_endpoints.pop() {
333 let res = ewc.endpoint.read_messages(lock, ewc.client.as_ref());
334 if let Err(e) = res {
335 if let Some(client) = &ewc.client {
336 log::error!("Could not handle client message: {}", Report::new(e));
337 self.add_client_to_kill(client);
338 } else {
339 if let EndpointError::HandleMessage(msg) = &e
340 && let ObjectErrorKind::ServerError(object, server_id, error, msg) =
341 &msg.source.0
342 && let Some(handler) = self.handler.borrow_mut().take()
343 {
344 handler.display_error(object.as_ref(), *server_id, *error, &msg.0)
345 }
346 return Err(StateErrorKind::DispatchEvents(e).into());
347 }
348 }
349 self.change_interest(&ewc.endpoint, |i| i | poll::READABLE);
350 }
351 self.has_readable_endpoints.set(false);
352 Ok(true)
353 }
354
355 fn change_interest(&self, endpoint: &Rc<Endpoint>, f: impl FnOnce(u32) -> u32) {
356 if self.destroyed.get() {
357 return;
358 }
359 let old = endpoint.desired_interest.get();
360 let new = f(old);
361 endpoint.desired_interest.set(new);
362 if old != new
363 && endpoint.current_interest.get() != new
364 && !endpoint.interest_update_queued.replace(true)
365 {
366 self.interest_update_endpoints.push(endpoint.clone());
367 self.has_interest_update_endpoints.set(true);
368 }
369 }
370
371 pub(crate) fn add_flushable_endpoint(
372 &self,
373 endpoint: &Rc<Endpoint>,
374 client: Option<&Rc<Client>>,
375 ) {
376 if self.destroyed.get() {
377 return;
378 }
379 self.flushable_endpoints.push(EndpointWithClient {
380 endpoint: endpoint.clone(),
381 client: client.cloned(),
382 });
383 self.has_flushable_endpoints.set(true);
384 }
385
386 fn wait_for_work(&self, _: &HandlerLock<'_>, mut timeout: c::c_int) -> Result<(), StateError> {
387 self.check_destroyed()?;
388 let mut events = [PollEvent::default(); poll::MAX_EVENTS];
389 let pollables = &mut *self.pollables.borrow_mut();
390 loop {
391 let n = self
392 .poller
393 .read_events(timeout, &mut events)
394 .map_err(StateErrorKind::PollError)?;
395 if n == 0 {
396 return Ok(());
397 }
398 timeout = 0;
399 for event in &events[0..n] {
400 let id = event.u64;
401 let Some(pollable) = pollables.get(&id) else {
402 continue;
403 };
404 match pollable {
405 Pollable::Endpoint(ewc) => {
406 let events = event.events;
407 if events & poll::ERROR != 0 {
408 if let Some(client) = &ewc.client {
409 self.add_client_to_kill(client);
410 } else {
411 return Err(StateErrorKind::ServerHangup.into());
412 }
413 continue;
414 }
415 ewc.endpoint.current_interest.set(0);
416 self.change_interest(&ewc.endpoint, |i| i & !events);
417 if events & poll::READABLE != 0 {
418 self.readable_endpoints.push(ewc.clone());
419 self.has_readable_endpoints.set(true);
420 }
421 if events & poll::WRITABLE != 0 {
422 self.flushable_endpoints.push(ewc.clone());
423 self.has_flushable_endpoints.set(true);
424 }
425 }
426 Pollable::Acceptor(a) => {
427 self.acceptable_acceptors.push(a.clone());
428 self.has_acceptable_acceptors.set(true);
429 }
430 Pollable::Destructor(fd, destroy) => {
431 let destroy = destroy.load(Acquire);
432 self.poller.unregister(fd.as_fd());
433 pollables.remove(&id);
434 if destroy {
435 return Err(StateErrorKind::RemoteDestroyed.into());
436 }
437 }
438 }
439 }
440 }
441 }
442
443 fn add_client_to_kill(&self, client: &Rc<Client>) {
444 self.clients_to_kill.push(client.clone());
445 self.has_clients_to_kill.set(true);
446 }
447
448 fn kill_clients(&self) -> bool {
449 if !self.has_clients_to_kill.get() {
450 return false;
451 }
452 while let Some(client) = self.clients_to_kill.pop() {
453 if let Some(handler) = client.handler.borrow_mut().take() {
454 handler.disconnected();
455 }
456 client.disconnect();
457 }
458 self.has_clients_to_kill.set(false);
459 true
460 }
461
462 fn create_pollable_id(&self) -> u64 {
463 let id = self.next_pollable_id.get();
464 self.next_pollable_id.set(id + 1);
465 id
466 }
467
468 fn update_interests(&self) -> Result<(), StateError> {
469 if self.has_interest_update_endpoints.get() {
470 while let Some(endpoint) = self.interest_update_endpoints.pop() {
471 endpoint.interest_update_queued.set(false);
472 let desired = endpoint.desired_interest.get();
473 if desired == endpoint.current_interest.get() {
474 continue;
475 }
476 if endpoint.unregistered.get() {
477 continue;
478 }
479 self.poller
480 .update_interests(endpoint.id, endpoint.socket.as_fd(), desired)
481 .map_err(StateErrorKind::PollError)?;
482 endpoint.current_interest.set(desired);
483 }
484 self.has_interest_update_endpoints.set(false);
485 }
486 if self.has_interest_update_acceptors.get() {
487 while let Some(acceptor) = self.interest_update_acceptors.pop() {
488 self.poller
489 .update_interests(acceptor.id, acceptor.socket.as_fd(), poll::READABLE)
490 .map_err(StateErrorKind::PollError)?;
491 }
492 self.has_interest_update_acceptors.set(false);
493 }
494 Ok(())
495 }
496
497 fn check_destroyed(&self) -> Result<(), StateError> {
498 if self.destroyed.get() {
499 return Err(StateErrorKind::Destroyed.into());
500 }
501 Ok(())
502 }
503
504 #[cfg(feature = "logging")]
505 #[cold]
506 pub(crate) fn log(&self, args: std::fmt::Arguments<'_>) {
507 use std::io::Write;
508 let writer = &mut *self.log_writer.borrow_mut();
509 let _ = writer.write_fmt(args);
510 let _ = writer.flush();
511 }
512}
513
514impl State {
516 pub fn builder(baseline: Baseline) -> StateBuilder {
518 StateBuilder::new(baseline)
519 }
520}
521
522impl State {
524 pub fn dispatch_blocking(self: &Rc<Self>) -> Result<bool, StateError> {
528 self.dispatch(None)
529 }
530
531 pub fn dispatch_available(self: &Rc<Self>) -> Result<bool, StateError> {
535 self.dispatch(Some(Duration::from_secs(0)))
536 }
537
538 pub fn dispatch(self: &Rc<Self>, timeout: Option<Duration>) -> Result<bool, StateError> {
554 let mut did_work = false;
555 let lock = self.acquire_handler_lock()?;
556 let timeout = timeout
557 .and_then(|t| t.as_millis().try_into().ok())
558 .unwrap_or(-1);
559 let destroy_on_error = on_drop(|| self.destroy());
560 if timeout != 0 {
561 did_work |= self.flush_locked(&lock)?;
562 }
563 self.wait_for_work(&lock, timeout)?;
564 did_work |= self.accept_connections(&lock)?;
565 did_work |= self.read_messages(&lock)?;
566 did_work |= self.flush_locked(&lock)?;
567 destroy_on_error.forget();
568 Ok(did_work)
569 }
570}
571
572impl State {
573 pub fn poll_fd(&self) -> &Rc<OwnedFd> {
580 self.poller.fd()
581 }
582
583 pub fn before_poll(&self) -> Result<(), StateError> {
602 let lock = self.acquire_handler_lock()?;
603 let destroy_on_error = on_drop(|| self.destroy());
604 self.flush_locked(&lock)?;
605 destroy_on_error.forget();
606 Ok(())
607 }
608}
609
610impl State {
612 pub fn create_object<P>(self: &Rc<Self>, version: u32) -> Rc<P>
627 where
628 P: Object,
629 {
630 P::new(self, version)
631 }
632
633 pub fn display(self: &Rc<Self>) -> Rc<WlDisplay> {
635 let display = WlDisplay::new(self, 1);
636 if self.server.is_some() {
637 display.core().server_obj_id.set(Some(1));
638 }
639 display
640 }
641
642 pub fn set_default_forward_to_client(&self, enabled: bool) {
647 self.forward_to_client.set(enabled);
648 }
649
650 pub fn set_default_forward_to_server(&self, enabled: bool) {
655 self.forward_to_server.set(enabled);
656 }
657}
658
659impl State {
661 pub fn connect(self: &Rc<Self>) -> Result<(Rc<Client>, OwnedFd), StateError> {
669 let (server_fd, client_fd) = uapi::socketpair(
670 c::AF_UNIX,
671 c::SOCK_STREAM | c::SOCK_NONBLOCK | c::SOCK_CLOEXEC,
672 0,
673 )
674 .map_err(|e| StateErrorKind::Socketpair(e.into()))?;
675 let client = self.create_client(None, &Rc::new(server_fd.into()))?;
676 Ok((client, client_fd.into()))
677 }
678
679 pub fn add_client(self: &Rc<Self>, socket: &Rc<OwnedFd>) -> Result<Rc<Client>, StateError> {
687 self.create_client(None, socket)
688 }
689
690 pub fn create_acceptor(&self, max_tries: u32) -> Result<Rc<Acceptor>, StateError> {
698 self.check_destroyed()?;
699 let id = self.create_pollable_id();
700 let acceptor =
701 Acceptor::create(id, max_tries, true).map_err(StateErrorKind::CreateAcceptor)?;
702 self.poller
703 .register(id, acceptor.socket.as_fd())
704 .map_err(StateErrorKind::PollError)?;
705 self.update_interests()?;
706 self.interest_update_acceptors.push(acceptor.clone());
707 self.has_interest_update_acceptors.set(true);
708 self.pollables
709 .borrow_mut()
710 .insert(id, Pollable::Acceptor(acceptor.clone()));
711 Ok(acceptor)
712 }
713
714 fn create_client(
715 self: &Rc<Self>,
716 lock: Option<&HandlerLock<'_>>,
717 socket: &Rc<OwnedFd>,
718 ) -> Result<Rc<Client>, StateError> {
719 self.check_destroyed()?;
720 let id = self.create_pollable_id();
721 self.poller
722 .register(id, socket.as_fd())
723 .map_err(StateErrorKind::PollError)?;
724 let endpoint = Endpoint::new(id, socket);
725 self.change_interest(&endpoint, |i| i | poll::READABLE);
726 self.update_interests()?;
727 let client = Rc::new(Client {
728 state: self.clone(),
729 endpoint: endpoint.clone(),
730 display: self.display(),
731 destroyed: Cell::new(false),
732 handler: Default::default(),
733 });
734 client
735 .display
736 .core()
737 .set_client_id(&client, 1, client.display.clone())
738 .unwrap();
739 self.pollables.borrow_mut().insert(
740 id,
741 Pollable::Endpoint(EndpointWithClient {
742 endpoint,
743 client: Some(client.clone()),
744 }),
745 );
746 if lock.is_some()
747 && let Some(handler) = &mut *self.handler.borrow_mut()
748 {
749 handler.new_client(&client);
750 }
751 Ok(client)
752 }
753}
754
755impl State {
760 pub fn unset_handler(&self) {
762 self.handler.set(None);
763 }
764
765 pub fn set_handler(&self, handler: impl StateHandler) {
767 self.set_boxed_handler(Box::new(handler))
768 }
769
770 pub fn set_boxed_handler(&self, handler: Box<dyn StateHandler>) {
772 if self.destroyed.get() {
773 return;
774 }
775 self.handler.set(Some(handler));
776 }
777}
778
impl State {
    /// Returns `true` while `destroy` has not run yet.
    pub fn is_not_destroyed(&self) -> bool {
        !self.is_destroyed()
    }

    /// Returns `true` once `destroy` has run (directly, or because a dispatch
    /// failed and its guard destroyed the state).
    pub fn is_destroyed(&self) -> bool {
        self.destroyed.get()
    }

    /// Tears the state down. Calling it again is a no-op.
    ///
    /// Steps:
    /// - Marks every client destroyed.
    /// - Unregisters every pollable fd from the poller.
    /// - Drops the object maps of all endpoints.
    /// - For every object that is still alive, removes its handler and detaches
    ///   its client.
    /// - Drops the state handler.
    /// - Clears all pollables, work queues and the object registry.
    pub fn destroy(&self) {
        // `replace` returns the previous value, so only the first call proceeds.
        if self.destroyed.replace(true) {
            return;
        }
        // Reused scratch buffer for collecting objects.
        let objects = &mut *self.object_stash.borrow();
        for pollable in self.pollables.borrow().values() {
            let fd = match pollable {
                Pollable::Endpoint(ewc) => {
                    if let Some(c) = &ewc.client {
                        c.destroyed.set(true);
                    }
                    // Move the endpoint's objects out of its map; they are
                    // dropped together by `objects.clear()` below.
                    objects.extend(ewc.endpoint.objects.borrow_mut().drain().map(|v| v.1));
                    &ewc.endpoint.socket
                }
                Pollable::Acceptor(a) => &a.socket,
                Pollable::Destructor(fd, _) => fd,
            };
            self.poller.unregister(fd.as_fd());
        }
        objects.clear();
        // Collect every object that is still alive elsewhere and detach it.
        for object in self.all_objects.borrow().values() {
            if let Some(object) = object.upgrade() {
                objects.push(object);
            }
        }
        for object in objects {
            object.unset_handler();
            object.core().client.take();
        }
        self.handler.set(None);
        self.pollables.borrow_mut().clear();
        self.acceptable_acceptors.take();
        self.clients_to_kill.take();
        self.readable_endpoints.take();
        self.flushable_endpoints.take();
        self.interest_update_endpoints.take();
        self.interest_update_acceptors.take();
        self.all_objects.borrow_mut().clear();
        // The returned RemoteDestructor is dropped immediately, which closes the
        // pipe's write end.
        // NOTE(review): presumably this is done to make the poll fd report an
        // event, so an external loop waiting on `poll_fd` wakes up and observes
        // the destruction — confirm.
        let _ = self.create_remote_destructor();
    }

    /// Returns a `Destructor` holding a strong reference to this state, initially enabled.
    ///
    /// Presumably it destroys the state when dropped while enabled — see the
    /// `destructor` module.
    pub fn create_destructor(self: &Rc<Self>) -> Destructor {
        Destructor {
            state: self.clone(),
            enabled: Cell::new(true),
        }
    }

    /// Creates a pipe whose read end is polled by this state and returns the
    /// write end wrapped in a `RemoteDestructor`.
    ///
    /// When the read end reports any event, `wait_for_work` unregisters it and
    /// checks the shared `destroy` flag. If the flag is set, dispatch fails with
    /// `RemoteDestroyed` (and the state is then destroyed by dispatch's guard).
    /// The `Arc`/`AtomicBool` suggest the `RemoteDestructor` may be used from
    /// another thread — confirm in the `destructor` module.
    pub fn create_remote_destructor(&self) -> Result<RemoteDestructor, StateError> {
        let (r, w) = pipe().map_err(StateErrorKind::CreatePipe)?;
        let r: OwnedFd = r.into();
        let id = self.create_pollable_id();
        // No interest is requested for the read end.
        // NOTE(review): presumably this relies on the poller always reporting
        // hang-up once the write end is closed — confirm against `Poller::register`.
        self.poller
            .register(id, r.as_fd())
            .map_err(StateErrorKind::PollError)?;
        let destroy = Arc::new(AtomicBool::new(false));
        self.pollables
            .borrow_mut()
            .insert(id, Pollable::Destructor(r, destroy.clone()));
        Ok(RemoteDestructor {
            destroy,
            _fd: w.into(),
            enabled: AtomicBool::new(true),
        })
    }
}
902
903impl StateError {
904 pub fn is_destroyed(&self) -> bool {
908 matches!(self.0, StateErrorKind::Destroyed)
909 }
910}
911
912impl Drop for HandlerLock<'_> {
913 fn drop(&mut self) {
914 self.state.global_lock_held.set(false);
915 }
916}