Skip to main content

wl_proxy/
state.rs

1//! The proxy state.
2
3use {
4    crate::{
5        acceptor::{Acceptor, AcceptorError},
6        baseline::Baseline,
7        client::Client,
8        endpoint::{Endpoint, EndpointError},
9        handler::HandlerHolder,
10        object::{Object, ObjectCoreApi, ObjectErrorKind, ObjectPrivate},
11        poll::{self, PollError, PollEvent, Poller},
12        protocols::wayland::wl_display::WlDisplay,
13        trans::{FlushResult, TransError},
14        utils::{
15            env::{WAYLAND_DISPLAY, WAYLAND_SOCKET, XDG_RUNTIME_DIR},
16            stack::Stack,
17            stash::Stash,
18        },
19    },
20    error_reporter::Report,
21    run_on_drop::on_drop,
22    std::{
23        cell::{Cell, RefCell},
24        collections::HashMap,
25        io::{self, pipe},
26        os::fd::{AsFd, OwnedFd},
27        rc::{Rc, Weak},
28        sync::{
29            Arc,
30            atomic::{AtomicBool, Ordering::Acquire},
31        },
32        time::Duration,
33    },
34    thiserror::Error,
35    uapi::c,
36};
37pub use {
38    builder::StateBuilder,
39    destructor::{Destructor, RemoteDestructor},
40};
41
42mod builder;
43mod destructor;
44#[cfg(test)]
45mod tests;
46
47/// An error emitted by a [`State`].
48#[derive(Debug, Error)]
49#[error(transparent)]
50pub struct StateError(#[from] StateErrorKind);
51
52#[derive(Debug, Error)]
53enum StateErrorKind {
54    #[error("the state has already been destroyed")]
55    Destroyed,
56    #[error("the state has been destroyed by a remote destructor")]
57    RemoteDestroyed,
58    #[error("cannot perform recursive call into the state")]
59    RecursiveCall,
60    #[error("the server hung up the connection")]
61    ServerHangup,
62    #[error("could not write to the server socket")]
63    WriteToServer(#[source] EndpointError),
64    #[error("could not dispatch server events")]
65    DispatchEvents(#[source] EndpointError),
66    #[error("could not create a socket pair")]
67    Socketpair(#[source] io::Error),
68    #[error(transparent)]
69    CreateAcceptor(AcceptorError),
70    #[error("could not accept a new connection")]
71    AcceptConnection(AcceptorError),
72    #[error("could not create a pipe")]
73    CreatePipe(#[source] io::Error),
74    #[error("could not read {} environment variable", WAYLAND_DISPLAY)]
75    WaylandDisplay,
76    #[error("the display name is empty")]
77    WaylandDisplayEmpty,
78    #[error("{} is not set", XDG_RUNTIME_DIR)]
79    XrdNotSet,
80    #[error("the socket path is too long")]
81    SocketPathTooLong,
82    #[error("could not create a socket")]
83    CreateSocket(#[source] io::Error),
84    #[error("could not connect to {0}")]
85    Connect(String, #[source] io::Error),
86    #[error("{} does not contain a valid number", WAYLAND_SOCKET)]
87    WaylandSocketNotNumber,
88    #[error("F_GETFD failed on {}", WAYLAND_SOCKET)]
89    WaylandSocketGetFd(#[source] io::Error),
90    #[error("F_SETFD failed on {}", WAYLAND_SOCKET)]
91    WaylandSocketSetFd(#[source] io::Error),
92    #[error(transparent)]
93    PollError(PollError),
94}
95
96/// The proxy state.
97///
98/// This type represents a connection to a server and any number of clients connected to
99/// this proxy.
100///
101/// This type can be constructed by using a [`StateBuilder`].
102///
103/// # Example
104///
105/// ```
106/// # use std::rc::Rc;
107/// # use wl_proxy::baseline::Baseline;
108/// # use wl_proxy::client::{Client, ClientHandler};
109/// # use wl_proxy::protocols::wayland::wl_display::{WlDisplay, WlDisplayHandler};
110/// # use wl_proxy::protocols::wayland::wl_registry::WlRegistry;
111/// # use wl_proxy::state::{State, StateBuilder, StateHandler};
112/// # fn f() {
113/// let state = State::builder(Baseline::ALL_OF_THEM).build().unwrap();
114/// let acceptor = state.create_acceptor(1000).unwrap();
115/// eprintln!("{}", acceptor.display());
116/// loop {
117///     state.dispatch_blocking().unwrap();
118/// }
119///
120/// struct StateHandlerImpl;
121///
122/// impl StateHandler for StateHandlerImpl {
123///     fn new_client(&mut self, client: &Rc<Client>) {
124///         eprintln!("Client connected");
125///         client.set_handler(ClientHandlerImpl);
126///         client.display().set_handler(DisplayHandler);
127///     }
128/// }
129///
130/// struct ClientHandlerImpl;
131///
132/// impl ClientHandler for ClientHandlerImpl {
133///     fn disconnected(self: Box<Self>) {
134///         eprintln!("Client disconnected");
135///     }
136/// }
137///
138/// struct DisplayHandler;
139///
140/// impl WlDisplayHandler for DisplayHandler {
141///     fn handle_get_registry(&mut self, slf: &Rc<WlDisplay>, registry: &Rc<WlRegistry>) {
142///         eprintln!("get_registry called");
143///         let _ = slf.send_get_registry(registry);
144///     }
145/// }
146/// # }
147/// ```
148pub struct State {
149    pub(crate) baseline: Baseline,
150    poller: Poller,
151    next_pollable_id: Cell<u64>,
152    pub(crate) server: Option<Rc<Endpoint>>,
153    pub(crate) destroyed: Cell<bool>,
154    handler: HandlerHolder<dyn StateHandler>,
155    pollables: RefCell<HashMap<u64, Pollable>>,
156    acceptable_acceptors: Stack<Rc<Acceptor>>,
157    has_acceptable_acceptors: Cell<bool>,
158    clients_to_kill: Stack<Rc<Client>>,
159    has_clients_to_kill: Cell<bool>,
160    readable_endpoints: Stack<EndpointWithClient>,
161    has_readable_endpoints: Cell<bool>,
162    flushable_endpoints: Stack<EndpointWithClient>,
163    has_flushable_endpoints: Cell<bool>,
164    interest_update_endpoints: Stack<Rc<Endpoint>>,
165    has_interest_update_endpoints: Cell<bool>,
166    interest_update_acceptors: Stack<Rc<Acceptor>>,
167    has_interest_update_acceptors: Cell<bool>,
168    pub(crate) all_objects: RefCell<HashMap<u64, Weak<dyn Object>>>,
169    pub(crate) next_object_id: Cell<u64>,
170    #[cfg(feature = "logging")]
171    pub(crate) log: bool,
172    #[cfg(feature = "logging")]
173    pub(crate) log_prefix: String,
174    #[cfg(feature = "logging")]
175    log_writer: RefCell<io::BufWriter<uapi::Fd>>,
176    global_lock_held: Cell<bool>,
177    pub(crate) object_stash: Stash<Rc<dyn Object>>,
178    pub(crate) forward_to_client: Cell<bool>,
179    pub(crate) forward_to_server: Cell<bool>,
180}
181
182/// A handler for events emitted by a [`State`].
183pub trait StateHandler: 'static {
184    /// A new client has connected.
185    ///
186    /// This event is not emitted if the connection is created explicitly via
187    /// [`State::connect`] or [`State::add_client`].
188    fn new_client(&mut self, client: &Rc<Client>) {
189        let _ = client;
190    }
191
192    /// The server has sent a wl_display.error event.
193    ///
194    /// Such errors are fatal.
195    ///
196    /// The object can be `None` if the error is sent on an object that has already been
197    /// deleted.
198    fn display_error(
199        self: Box<Self>,
200        object: Option<&Rc<dyn Object>>,
201        server_id: u32,
202        error: u32,
203        msg: &str,
204    ) {
205        let _ = object;
206        let _ = server_id;
207        let _ = error;
208        let _ = msg;
209    }
210}
211
212enum Pollable {
213    Endpoint(EndpointWithClient),
214    Acceptor(Rc<Acceptor>),
215    Destructor(OwnedFd, Arc<AtomicBool>),
216}
217
218#[derive(Clone)]
219struct EndpointWithClient {
220    endpoint: Rc<Endpoint>,
221    client: Option<Rc<Client>>,
222}
223
224pub(crate) struct HandlerLock<'a> {
225    state: &'a State,
226}
227
228impl State {
229    pub(crate) fn remove_endpoint(&self, endpoint: &Endpoint) {
230        self.pollables.borrow_mut().remove(&endpoint.id);
231        self.poller.unregister(endpoint.socket.as_fd());
232        endpoint.unregistered.set(true);
233    }
234
235    fn acquire_handler_lock(&self) -> Result<HandlerLock<'_>, StateErrorKind> {
236        if self.global_lock_held.replace(true) {
237            return Err(StateErrorKind::RecursiveCall);
238        }
239        Ok(HandlerLock { state: self })
240    }
241
242    fn flush_locked(&self, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
243        let mut did_work = false;
244        did_work |= self.perform_writes(lock)?;
245        did_work |= self.kill_clients();
246        self.update_interests()?;
247        Ok(did_work)
248    }
249
250    pub(crate) fn handle_delete_id(&self, server: &Endpoint, id: u32) {
251        let object = server.objects.borrow_mut().remove(&id).unwrap();
252        let core = object.core();
253        core.server_obj_id.take();
254        server.idl.release(id);
255        if let Err((e, object)) = object.delete_id() {
256            log::warn!(
257                "Could not handle a wl_display.delete_id message: {}",
258                Report::new(e),
259            );
260            let _ = object.core().try_delete_id();
261        }
262    }
263
264    fn perform_writes(&self, _: &HandlerLock<'_>) -> Result<bool, StateError> {
265        if !self.has_flushable_endpoints.get() {
266            return Ok(false);
267        }
268        while let Some(ewc) = self.flushable_endpoints.pop() {
269            let res = match ewc.endpoint.flush() {
270                Ok(r) => r,
271                Err(e) => {
272                    let is_closed = matches!(e, EndpointError::Flush(TransError::Closed));
273                    if let Some(client) = &ewc.client {
274                        if !is_closed {
275                            log::warn!(
276                                "Could not write to client#{}: {}",
277                                client.endpoint.id,
278                                Report::new(e),
279                            );
280                        }
281                        self.add_client_to_kill(client);
282                    } else {
283                        if is_closed {
284                            return Err(StateErrorKind::ServerHangup.into());
285                        }
286                        return Err(StateErrorKind::WriteToServer(e).into());
287                    }
288                    continue;
289                }
290            };
291            match res {
292                FlushResult::Done => {
293                    ewc.endpoint.flush_queued.set(false);
294                    self.change_interest(&ewc.endpoint, |i| i & !poll::WRITABLE);
295                }
296                FlushResult::Blocked => {
297                    self.change_interest(&ewc.endpoint, |i| i | poll::WRITABLE);
298                }
299            }
300        }
301        self.has_flushable_endpoints.set(false);
302        Ok(true)
303    }
304
305    fn accept_connections(self: &Rc<Self>, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
306        if !self.has_acceptable_acceptors.get() {
307            return Ok(false);
308        }
309        self.check_destroyed()?;
310        while let Some(acceptor) = self.acceptable_acceptors.pop() {
311            self.interest_update_acceptors.push(acceptor.clone());
312            self.has_interest_update_acceptors.set(true);
313            const MAX_ACCEPT_PER_ITERATION: usize = 10;
314            for _ in 0..MAX_ACCEPT_PER_ITERATION {
315                let socket = acceptor
316                    .accept()
317                    .map_err(StateErrorKind::AcceptConnection)?;
318                let Some(socket) = socket else {
319                    break;
320                };
321                self.create_client(Some(lock), &Rc::new(socket))?;
322            }
323        }
324        self.has_acceptable_acceptors.set(false);
325        Ok(true)
326    }
327
328    fn read_messages(&self, lock: &HandlerLock<'_>) -> Result<bool, StateError> {
329        if !self.has_readable_endpoints.get() {
330            return Ok(false);
331        }
332        while let Some(ewc) = self.readable_endpoints.pop() {
333            let res = ewc.endpoint.read_messages(lock, ewc.client.as_ref());
334            if let Err(e) = res {
335                if let Some(client) = &ewc.client {
336                    log::error!("Could not handle client message: {}", Report::new(e));
337                    self.add_client_to_kill(client);
338                } else {
339                    if let EndpointError::HandleMessage(msg) = &e
340                        && let ObjectErrorKind::ServerError(object, server_id, error, msg) =
341                            &msg.source.0
342                        && let Some(handler) = self.handler.borrow_mut().take()
343                    {
344                        handler.display_error(object.as_ref(), *server_id, *error, &msg.0)
345                    }
346                    return Err(StateErrorKind::DispatchEvents(e).into());
347                }
348            }
349            self.change_interest(&ewc.endpoint, |i| i | poll::READABLE);
350        }
351        self.has_readable_endpoints.set(false);
352        Ok(true)
353    }
354
355    fn change_interest(&self, endpoint: &Rc<Endpoint>, f: impl FnOnce(u32) -> u32) {
356        if self.destroyed.get() {
357            return;
358        }
359        let old = endpoint.desired_interest.get();
360        let new = f(old);
361        endpoint.desired_interest.set(new);
362        if old != new
363            && endpoint.current_interest.get() != new
364            && !endpoint.interest_update_queued.replace(true)
365        {
366            self.interest_update_endpoints.push(endpoint.clone());
367            self.has_interest_update_endpoints.set(true);
368        }
369    }
370
371    pub(crate) fn add_flushable_endpoint(
372        &self,
373        endpoint: &Rc<Endpoint>,
374        client: Option<&Rc<Client>>,
375    ) {
376        if self.destroyed.get() {
377            return;
378        }
379        self.flushable_endpoints.push(EndpointWithClient {
380            endpoint: endpoint.clone(),
381            client: client.cloned(),
382        });
383        self.has_flushable_endpoints.set(true);
384    }
385
386    fn wait_for_work(&self, _: &HandlerLock<'_>, mut timeout: c::c_int) -> Result<(), StateError> {
387        self.check_destroyed()?;
388        let mut events = [PollEvent::default(); poll::MAX_EVENTS];
389        let pollables = &mut *self.pollables.borrow_mut();
390        loop {
391            let n = self
392                .poller
393                .read_events(timeout, &mut events)
394                .map_err(StateErrorKind::PollError)?;
395            if n == 0 {
396                return Ok(());
397            }
398            timeout = 0;
399            for event in &events[0..n] {
400                let id = event.u64;
401                let Some(pollable) = pollables.get(&id) else {
402                    continue;
403                };
404                match pollable {
405                    Pollable::Endpoint(ewc) => {
406                        let events = event.events;
407                        if events & poll::ERROR != 0 {
408                            if let Some(client) = &ewc.client {
409                                self.add_client_to_kill(client);
410                            } else {
411                                return Err(StateErrorKind::ServerHangup.into());
412                            }
413                            continue;
414                        }
415                        ewc.endpoint.current_interest.set(0);
416                        self.change_interest(&ewc.endpoint, |i| i & !events);
417                        if events & poll::READABLE != 0 {
418                            self.readable_endpoints.push(ewc.clone());
419                            self.has_readable_endpoints.set(true);
420                        }
421                        if events & poll::WRITABLE != 0 {
422                            self.flushable_endpoints.push(ewc.clone());
423                            self.has_flushable_endpoints.set(true);
424                        }
425                    }
426                    Pollable::Acceptor(a) => {
427                        self.acceptable_acceptors.push(a.clone());
428                        self.has_acceptable_acceptors.set(true);
429                    }
430                    Pollable::Destructor(fd, destroy) => {
431                        let destroy = destroy.load(Acquire);
432                        self.poller.unregister(fd.as_fd());
433                        pollables.remove(&id);
434                        if destroy {
435                            return Err(StateErrorKind::RemoteDestroyed.into());
436                        }
437                    }
438                }
439            }
440        }
441    }
442
443    fn add_client_to_kill(&self, client: &Rc<Client>) {
444        self.clients_to_kill.push(client.clone());
445        self.has_clients_to_kill.set(true);
446    }
447
448    fn kill_clients(&self) -> bool {
449        if !self.has_clients_to_kill.get() {
450            return false;
451        }
452        while let Some(client) = self.clients_to_kill.pop() {
453            if let Some(handler) = client.handler.borrow_mut().take() {
454                handler.disconnected();
455            }
456            client.disconnect();
457        }
458        self.has_clients_to_kill.set(false);
459        true
460    }
461
462    fn create_pollable_id(&self) -> u64 {
463        let id = self.next_pollable_id.get();
464        self.next_pollable_id.set(id + 1);
465        id
466    }
467
468    fn update_interests(&self) -> Result<(), StateError> {
469        if self.has_interest_update_endpoints.get() {
470            while let Some(endpoint) = self.interest_update_endpoints.pop() {
471                endpoint.interest_update_queued.set(false);
472                let desired = endpoint.desired_interest.get();
473                if desired == endpoint.current_interest.get() {
474                    continue;
475                }
476                if endpoint.unregistered.get() {
477                    continue;
478                }
479                self.poller
480                    .update_interests(endpoint.id, endpoint.socket.as_fd(), desired)
481                    .map_err(StateErrorKind::PollError)?;
482                endpoint.current_interest.set(desired);
483            }
484            self.has_interest_update_endpoints.set(false);
485        }
486        if self.has_interest_update_acceptors.get() {
487            while let Some(acceptor) = self.interest_update_acceptors.pop() {
488                self.poller
489                    .update_interests(acceptor.id, acceptor.socket.as_fd(), poll::READABLE)
490                    .map_err(StateErrorKind::PollError)?;
491            }
492            self.has_interest_update_acceptors.set(false);
493        }
494        Ok(())
495    }
496
497    fn check_destroyed(&self) -> Result<(), StateError> {
498        if self.destroyed.get() {
499            return Err(StateErrorKind::Destroyed.into());
500        }
501        Ok(())
502    }
503
504    #[cfg(feature = "logging")]
505    #[cold]
506    pub(crate) fn log(&self, args: std::fmt::Arguments<'_>) {
507        use std::io::Write;
508        let writer = &mut *self.log_writer.borrow_mut();
509        let _ = writer.write_fmt(args);
510        let _ = writer.flush();
511    }
512}
513
514/// These functions can be used to create a new state.
515impl State {
516    /// Creates a new [`StateBuilder`].
517    pub fn builder(baseline: Baseline) -> StateBuilder {
518        StateBuilder::new(baseline)
519    }
520}
521
522/// These functions can be used to dispatch and flush messages.
523impl State {
524    /// Performs a blocking dispatch.
525    ///
526    /// This is a shorthand for `self.dispatch(None)`.
527    pub fn dispatch_blocking(self: &Rc<Self>) -> Result<bool, StateError> {
528        self.dispatch(None)
529    }
530
531    /// Performs a non-blocking dispatch.
532    ///
533    /// This is a shorthand for `self.dispatch(Some(Duration::from_secs(0))`.
534    pub fn dispatch_available(self: &Rc<Self>) -> Result<bool, StateError> {
535        self.dispatch(Some(Duration::from_secs(0)))
536    }
537
538    /// Performs a dispatch.
539    ///
540    /// The timeout determines how long this function will wait for new events. If the
541    /// timeout is `None`, then it will wait indefinitely. If the timeout is `0`, then
542    /// it will only process currently available events.
543    ///
544    /// If the timeout is not `0`, then outgoing messages will be flushed before waiting.
545    ///
546    /// Outgoing messages will be flushed immediately before this function returns.
547    ///
548    /// The return value indicates if any work was performed.
549    ///
550    /// This function is not reentrant. It should not be called from within a callback.
551    /// Trying to do so will cause it to return an error immediately and the state will
552    /// be otherwise unchanged.
553    pub fn dispatch(self: &Rc<Self>, timeout: Option<Duration>) -> Result<bool, StateError> {
554        let mut did_work = false;
555        let lock = self.acquire_handler_lock()?;
556        let timeout = timeout
557            .and_then(|t| t.as_millis().try_into().ok())
558            .unwrap_or(-1);
559        let destroy_on_error = on_drop(|| self.destroy());
560        if timeout != 0 {
561            did_work |= self.flush_locked(&lock)?;
562        }
563        self.wait_for_work(&lock, timeout)?;
564        did_work |= self.accept_connections(&lock)?;
565        did_work |= self.read_messages(&lock)?;
566        did_work |= self.flush_locked(&lock)?;
567        destroy_on_error.forget();
568        Ok(did_work)
569    }
570}
571
572impl State {
573    /// Returns a file descriptor that can be used with epoll or similar.
574    ///
575    /// If this file descriptor becomes readable, the state should be dispatched.
576    /// [`Self::before_poll`] should be used before going to sleep.
577    ///
578    /// This function always returns the same file descriptor.
579    pub fn poll_fd(&self) -> &Rc<OwnedFd> {
580        self.poller.fd()
581    }
582
583    /// Prepares the state for an external poll operation.
584    ///
585    /// If [`Self::poll_fd`] is used, this function should be called immediately before
586    /// going to sleep. Otherwise, outgoing messages might not be flushed.
587    ///
588    /// ```
589    /// # use std::os::fd::OwnedFd;
590    /// # use std::rc::Rc;
591    /// # use wl_proxy::state::State;
592    /// # fn poll(fd: &OwnedFd) { }
593    /// # fn f(state: &Rc<State>) {
594    /// loop {
595    ///     state.before_poll().unwrap();
596    ///     poll(state.poll_fd());
597    ///     state.dispatch_available().unwrap();
598    /// }
599    /// # }
600    /// ```
601    pub fn before_poll(&self) -> Result<(), StateError> {
602        let lock = self.acquire_handler_lock()?;
603        let destroy_on_error = on_drop(|| self.destroy());
604        self.flush_locked(&lock)?;
605        destroy_on_error.forget();
606        Ok(())
607    }
608}
609
610/// These functions can be used to manipulate objects.
611impl State {
612    /// Creates a new object.
613    ///
614    /// The new object is not associated with a client ID or a server ID. It can become
615    /// associated with a client ID by sending an event with a `new_id` parameter. It can
616    /// become associated with a server ID by sending a request with a `new_id` parameter.
617    ///
618    /// The object can only be associated with one client at a time. The association with
619    /// a client is removed when the object is used in a destructor event.
620    ///
621    /// This function does not enforce that the version is less than or equal to the
622    /// maximum version supported by this crate. Using a version that exceeds tha maximum
623    /// supported version can cause a protocol error if the client sends a request that is
624    /// not available in the maximum supported protocol version or if the server sends an
625    /// event that is not available in the maximum supported protocol version.
626    pub fn create_object<P>(self: &Rc<Self>, version: u32) -> Rc<P>
627    where
628        P: Object,
629    {
630        P::new(self, version)
631    }
632
633    /// Returns a wl_display object.
634    pub fn display(self: &Rc<Self>) -> Rc<WlDisplay> {
635        let display = WlDisplay::new(self, 1);
636        if self.server.is_some() {
637            display.core().server_obj_id.set(Some(1));
638        }
639        display
640    }
641
642    /// Changes the default forward-to-client setting.
643    ///
644    /// This affects objects created after this call. See
645    /// [`ObjectCoreApi::set_forward_to_client`].
646    pub fn set_default_forward_to_client(&self, enabled: bool) {
647        self.forward_to_client.set(enabled);
648    }
649
650    /// Changes the default forward-to-server setting.
651    ///
652    /// This affects objects created after this call. See
653    /// [`ObjectCoreApi::set_forward_to_server`].
654    pub fn set_default_forward_to_server(&self, enabled: bool) {
655        self.forward_to_server.set(enabled);
656    }
657}
658
659/// These functions can be used to manage sockets associated with this state.
660impl State {
661    /// Creates a new connection to this proxy.
662    ///
663    /// The returned file descriptor is the client end of the connection and can be used
664    /// with a function such as `wl_display_connect_to_fd` or with the `WAYLAND_SOCKET`
665    /// environment variable.
666    ///
667    /// The [`StateHandler::new_client`] callback will not be invoked.
668    pub fn connect(self: &Rc<Self>) -> Result<(Rc<Client>, OwnedFd), StateError> {
669        let (server_fd, client_fd) = uapi::socketpair(
670            c::AF_UNIX,
671            c::SOCK_STREAM | c::SOCK_NONBLOCK | c::SOCK_CLOEXEC,
672            0,
673        )
674        .map_err(|e| StateErrorKind::Socketpair(e.into()))?;
675        let client = self.create_client(None, &Rc::new(server_fd.into()))?;
676        Ok((client, client_fd.into()))
677    }
678
679    /// Creates a new connection to this proxy from an existing socket.
680    ///
681    /// The file descriptor should be the server end of the connection. It can be created
682    /// with a function such as `socketpair` or by accepting a connection from a
683    /// file-system socket.
684    ///
685    /// The [`StateHandler::new_client`] callback will not be invoked.
686    pub fn add_client(self: &Rc<Self>, socket: &Rc<OwnedFd>) -> Result<Rc<Client>, StateError> {
687        self.create_client(None, socket)
688    }
689
690    /// Creates a new file-system acceptor and starts listening for connections.
691    ///
692    /// See [`Acceptor::new`] for the meaning of the `max_tries` parameter.
693    ///
694    /// Calling [`State::dispatch`] will automatically accept connections from this
695    /// acceptor. The [`StateHandler::new_client`] callback will be invoked when this
696    /// happens.
697    pub fn create_acceptor(&self, max_tries: u32) -> Result<Rc<Acceptor>, StateError> {
698        self.check_destroyed()?;
699        let id = self.create_pollable_id();
700        let acceptor =
701            Acceptor::create(id, max_tries, true).map_err(StateErrorKind::CreateAcceptor)?;
702        self.poller
703            .register(id, acceptor.socket.as_fd())
704            .map_err(StateErrorKind::PollError)?;
705        self.update_interests()?;
706        self.interest_update_acceptors.push(acceptor.clone());
707        self.has_interest_update_acceptors.set(true);
708        self.pollables
709            .borrow_mut()
710            .insert(id, Pollable::Acceptor(acceptor.clone()));
711        Ok(acceptor)
712    }
713
714    fn create_client(
715        self: &Rc<Self>,
716        lock: Option<&HandlerLock<'_>>,
717        socket: &Rc<OwnedFd>,
718    ) -> Result<Rc<Client>, StateError> {
719        self.check_destroyed()?;
720        let id = self.create_pollable_id();
721        self.poller
722            .register(id, socket.as_fd())
723            .map_err(StateErrorKind::PollError)?;
724        let endpoint = Endpoint::new(id, socket);
725        self.change_interest(&endpoint, |i| i | poll::READABLE);
726        self.update_interests()?;
727        let client = Rc::new(Client {
728            state: self.clone(),
729            endpoint: endpoint.clone(),
730            display: self.display(),
731            destroyed: Cell::new(false),
732            handler: Default::default(),
733        });
734        client
735            .display
736            .core()
737            .set_client_id(&client, 1, client.display.clone())
738            .unwrap();
739        self.pollables.borrow_mut().insert(
740            id,
741            Pollable::Endpoint(EndpointWithClient {
742                endpoint,
743                client: Some(client.clone()),
744            }),
745        );
746        if lock.is_some()
747            && let Some(handler) = &mut *self.handler.borrow_mut()
748        {
749            handler.new_client(&client);
750        }
751        Ok(client)
752    }
753}
754
755/// These functions can be used to manipulate the [`StateHandler`] of this state.
756///
757/// These functions can be called at any time, even from within a handler callback. In
758/// that case, the handler is replaced as soon as the callback returns.
759impl State {
760    /// Unsets the handler.
761    pub fn unset_handler(&self) {
762        self.handler.set(None);
763    }
764
765    /// Sets a new handler.
766    pub fn set_handler(&self, handler: impl StateHandler) {
767        self.set_boxed_handler(Box::new(handler))
768    }
769
770    /// Sets a new, already boxed handler.
771    pub fn set_boxed_handler(&self, handler: Box<dyn StateHandler>) {
772        if self.destroyed.get() {
773            return;
774        }
775        self.handler.set(Some(handler));
776    }
777}
778
779/// These functions can be used to check the state status and to destroy the state.
780impl State {
781    /// Returns whether this state is not destroyed.
782    ///
783    /// This is the same as `!self.is_destroyed()`.
784    pub fn is_not_destroyed(&self) -> bool {
785        !self.is_destroyed()
786    }
787
788    /// Returns whether the state is destroyed.
789    ///
790    /// If the state is destroyed, most functions that can return an error will return an
791    /// error saying that the state is already destroyed.
792    ///
793    /// This function or [`Self::is_not_destroyed`] should be used before dispatching the
794    /// state.
795    ///
796    /// # Example
797    ///
798    /// ```
799    /// # use std::rc::Rc;
800    /// # use error_reporter::Report;
801    /// # use wl_proxy::state::State;
802    /// #
803    /// # fn f(state: &Rc<State>) {
804    /// while state.is_not_destroyed() {
805    ///     if let Err(e) = state.dispatch_blocking() {
806    ///         log::error!("Could not dispatch the state: {}", Report::new(e));
807    ///     }
808    /// }
809    /// # }
810    /// ```
811    pub fn is_destroyed(&self) -> bool {
812        self.destroyed.get()
813    }
814
815    /// Destroys this state.
816    ///
817    /// This function unsets all handlers and destroys all clients. You should drop the
818    /// state after calling this function.
819    pub fn destroy(&self) {
820        if self.destroyed.replace(true) {
821            return;
822        }
823        let objects = &mut *self.object_stash.borrow();
824        for pollable in self.pollables.borrow().values() {
825            let fd = match pollable {
826                Pollable::Endpoint(ewc) => {
827                    if let Some(c) = &ewc.client {
828                        c.destroyed.set(true);
829                    }
830                    objects.extend(ewc.endpoint.objects.borrow_mut().drain().map(|v| v.1));
831                    &ewc.endpoint.socket
832                }
833                Pollable::Acceptor(a) => &a.socket,
834                Pollable::Destructor(fd, _) => fd,
835            };
836            self.poller.unregister(fd.as_fd());
837        }
838        objects.clear();
839        for object in self.all_objects.borrow().values() {
840            if let Some(object) = object.upgrade() {
841                objects.push(object);
842            }
843        }
844        for object in objects {
845            object.unset_handler();
846            object.core().client.take();
847        }
848        self.handler.set(None);
849        self.pollables.borrow_mut().clear();
850        self.acceptable_acceptors.take();
851        self.clients_to_kill.take();
852        self.readable_endpoints.take();
853        self.flushable_endpoints.take();
854        self.interest_update_endpoints.take();
855        self.interest_update_acceptors.take();
856        self.all_objects.borrow_mut().clear();
857        // Ensure that the poll fd stays permanently readable.
858        let _ = self.create_remote_destructor();
859    }
860
861    /// Creates a RAII destructor for this state.
862    ///
863    /// Dropping the destructor will automatically call [`State::destroy`] unless you
864    /// first call [`Destructor::disable`].
865    ///
866    /// State objects contain reference cycles that must be cleared manually to release
867    /// the associated resources. Dropping the [`State`] is usually not sufficient to do
868    /// this. Instead, [`State::destroy`] must be called manually. This function can be
869    /// used to accomplish this in an application that otherwise relies on RAII semantics.
870    ///
871    /// Ensure that the destructor is itself not part of a reference cycle.
872    pub fn create_destructor(self: &Rc<Self>) -> Destructor {
873        Destructor {
874            state: self.clone(),
875            enabled: Cell::new(true),
876        }
877    }
878
879    /// Creates a `Sync+Send` RAII destructor for this state.
880    ///
881    /// This function is similar to [`State::create_destructor`] but the returned
882    /// destructor implements `Sync+Send`. This destructor can therefore be used to
883    /// destroy states running in a different thread.
884    pub fn create_remote_destructor(&self) -> Result<RemoteDestructor, StateError> {
885        let (r, w) = pipe().map_err(StateErrorKind::CreatePipe)?;
886        let r: OwnedFd = r.into();
887        let id = self.create_pollable_id();
888        self.poller
889            .register(id, r.as_fd())
890            .map_err(StateErrorKind::PollError)?;
891        let destroy = Arc::new(AtomicBool::new(false));
892        self.pollables
893            .borrow_mut()
894            .insert(id, Pollable::Destructor(r, destroy.clone()));
895        Ok(RemoteDestructor {
896            destroy,
897            _fd: w.into(),
898            enabled: AtomicBool::new(true),
899        })
900    }
901}
902
903impl StateError {
904    /// Returns whether this error was emitted because the state is already destroyed.
905    ///
906    /// This can be used to determine the severity of emitted log messages.
907    pub fn is_destroyed(&self) -> bool {
908        matches!(self.0, StateErrorKind::Destroyed)
909    }
910}
911
912impl Drop for HandlerLock<'_> {
913    fn drop(&mut self) {
914        self.state.global_lock_held.set(false);
915    }
916}