Skip to main content

wl_proxy/state/
builder.rs

1use {
2    crate::{
3        baseline::Baseline,
4        endpoint::Endpoint,
5        object::{Object, ObjectPrivate},
6        poll::{self, Poller},
7        protocols::wayland::wl_display::WlDisplay,
8        state::{EndpointWithClient, Pollable, State, StateError, StateErrorKind},
9        utils::env::{WAYLAND_DISPLAY, WAYLAND_SOCKET, WL_PROXY_DEBUG, XDG_RUNTIME_DIR},
10    },
11    linearize::Linearize,
12    std::{
13        cell::{Cell, RefCell},
14        collections::HashMap,
15        env::{remove_var, var, var_os},
16        os::{
17            fd::{AsFd, FromRawFd, OwnedFd},
18            unix::ffi::OsStrExt,
19        },
20        rc::Rc,
21        str::FromStr,
22    },
23    uapi::c::{self, sockaddr_un},
24};
25
26/// A builder for a [`State`].
27///
28/// This type can be constructed with [`State::builder`].
29pub struct StateBuilder {
30    baseline: Baseline,
31    server: Option<Server>,
32    log: bool,
33    log_prefix: String,
34}
35
/// The server explicitly selected on a [`StateBuilder`].
enum Server {
    /// Connect to this Wayland display name; it is resolved the same way as a
    /// name taken from `WAYLAND_DISPLAY`.
    DisplayName(String),
    /// Use this file descriptor as the server connection as-is.
    Fd(Rc<OwnedFd>),
    /// Build the state without any server.
    None,
}
41
/// Pollable IDs that every [`State`] reserves up front.
///
/// Each variant is used directly as a `u64` pollable ID (`variant as u64`), and
/// dynamically allocated pollable IDs start at `StaticPollableIds::LENGTH`.
/// The declaration order therefore fixes the reserved IDs; do not reorder.
#[derive(Copy, Clone, Linearize)]
pub(crate) enum StaticPollableIds {
    /// The connection to the Wayland server (reserved even when there is no server).
    Server,
    /// The `unsuspend_fd` eventfd.
    Unsuspend,
}
47
48impl StateBuilder {
49    pub(super) fn new(baseline: Baseline) -> Self {
50        Self {
51            baseline,
52            server: Default::default(),
53            log: var(WL_PROXY_DEBUG).as_deref() == Ok("1"),
54            log_prefix: Default::default(),
55        }
56    }
57
58    /// Builds the state.
59    ///
60    /// The server to connect to is chosen as follows:
61    ///
62    /// - If [`Self::with_server_fd`] was used, that FD is used.
63    /// - Otherwise, if [`Self::with_server_display_name`] was used, that display name is
64    ///   used.
65    /// - Otherwise, if the `WAYLAND_SOCKET` environment variable is set, that FD is used.
66    /// - Otherwise, the display name from the `WAYLAND_DISPLAY` environment variable is
67    ///   used.
68    pub fn build(self) -> Result<Rc<State>, StateError> {
69        let server_fd = 'fd: {
70            let display_name = match self.server {
71                None => None,
72                Some(Server::None) => break 'fd None,
73                Some(Server::Fd(fd)) => break 'fd Some(fd),
74                Some(Server::DisplayName(n)) => Some(n),
75            };
76            if display_name.is_none()
77                && let Some(wayland_socket) = var_os(WAYLAND_SOCKET)
78            {
79                let fd = str::from_utf8(wayland_socket.as_bytes())
80                    .ok()
81                    .and_then(|s| i32::from_str(s).ok())
82                    .ok_or(StateErrorKind::WaylandSocketNotNumber)?;
83                let flags = uapi::fcntl_getfd(fd)
84                    .map_err(|e| StateErrorKind::WaylandSocketGetFd(e.into()))?;
85                uapi::fcntl_setfd(fd, flags | c::FD_CLOEXEC)
86                    .map_err(|e| StateErrorKind::WaylandSocketSetFd(e.into()))?;
87                // SAFETY: This is unsound.
88                let fd = unsafe {
89                    remove_var(WAYLAND_SOCKET);
90                    Rc::new(OwnedFd::from_raw_fd(fd))
91                };
92                break 'fd Some(fd);
93            }
94            let mut name = match display_name {
95                Some(n) => n,
96                _ => var(WAYLAND_DISPLAY)
97                    .ok()
98                    .ok_or(StateErrorKind::WaylandDisplay)?,
99            };
100            if name.is_empty() {
101                return Err(StateErrorKind::WaylandDisplayEmpty.into());
102            }
103            if !name.starts_with("/") {
104                let Ok(xrd) = var(XDG_RUNTIME_DIR) else {
105                    return Err(StateErrorKind::XrdNotSet.into());
106                };
107                name = format!("{xrd}/{name}");
108            }
109            let mut addr = sockaddr_un {
110                sun_family: c::AF_UNIX as _,
111                sun_path: [0; 108],
112            };
113            if name.len() > addr.sun_path.len() - 1 {
114                return Err(StateErrorKind::SocketPathTooLong.into());
115            }
116            let sun_path = uapi::as_bytes_mut(&mut addr.sun_path[..]);
117            sun_path[..name.len()].copy_from_slice(name.as_bytes());
118            sun_path[name.len()] = 0;
119            let socket = uapi::socket(c::AF_UNIX, c::SOCK_STREAM | c::SOCK_CLOEXEC, 0)
120                .map_err(|e| StateErrorKind::CreateSocket(e.into()))?;
121            uapi::connect(socket.raw(), &addr)
122                .map_err(|e| StateErrorKind::Connect(name.to_string(), e.into()))?;
123            Some(Rc::new(socket.into()))
124        };
125        let mut endpoints = HashMap::new();
126        let mut server = None;
127        if let Some(server_fd) = &server_fd {
128            let s = Endpoint::new(StaticPollableIds::Server as u64, server_fd);
129            s.idl.acquire();
130            s.idl.acquire();
131            endpoints.insert(
132                StaticPollableIds::Server as u64,
133                Pollable::Endpoint(EndpointWithClient {
134                    endpoint: s.clone(),
135                    client: None,
136                }),
137            );
138            server = Some(s);
139        }
140        let unsuspend_fd = uapi::eventfd(0, c::EFD_CLOEXEC | c::EFD_NONBLOCK)
141            .map(Into::into)
142            .map_err(|e| StateErrorKind::CreateEventfd(e.into()))?;
143        endpoints.insert(StaticPollableIds::Unsuspend as u64, Pollable::Unsuspend);
144        let poller = Poller::new().map_err(StateErrorKind::PollError)?;
145        #[cfg(feature = "logging")]
146        let log_prefix = {
147            use {crate::utils::env::WL_PROXY_PREFIX, isnt::std_1::string::IsntStringExt};
148            let mut log_prefix = String::new();
149            if let Ok(prefix) = var(WL_PROXY_PREFIX) {
150                log_prefix = prefix;
151            }
152            if self.log_prefix.is_not_empty() {
153                if log_prefix.is_not_empty() {
154                    log_prefix.push_str(" ");
155                }
156                log_prefix.push_str(&self.log_prefix);
157            }
158            if log_prefix.is_not_empty() {
159                log_prefix = format!("{{{}}} ", log_prefix);
160            }
161            log_prefix
162        };
163        let state = Rc::new(State {
164            baseline: self.baseline,
165            poller,
166            next_pollable_id: Cell::new(StaticPollableIds::LENGTH as u64),
167            server,
168            destroyed: Default::default(),
169            handler: Default::default(),
170            pollables: RefCell::new(endpoints),
171            acceptable_acceptors: Default::default(),
172            has_acceptable_acceptors: Default::default(),
173            clients_to_kill: Default::default(),
174            has_clients_to_kill: Default::default(),
175            readable_endpoints: Default::default(),
176            has_readable_endpoints: Default::default(),
177            flushable_endpoints: Default::default(),
178            has_flushable_endpoints: Default::default(),
179            interest_update_endpoints: Default::default(),
180            has_interest_update_endpoints: Default::default(),
181            interest_update_acceptors: Default::default(),
182            has_interest_update_acceptors: Default::default(),
183            all_objects: Default::default(),
184            next_object_id: Cell::new(1),
185            #[cfg(feature = "logging")]
186            log: self.log,
187            #[cfg(feature = "logging")]
188            log_prefix,
189            #[cfg(feature = "logging")]
190            log_writer: RefCell::new(std::io::BufWriter::with_capacity(
191                1024,
192                uapi::Fd::new(c::STDERR_FILENO),
193            )),
194            global_lock_held: Default::default(),
195            object_stash: Default::default(),
196            forward_to_client: Cell::new(true),
197            forward_to_server: Cell::new(true),
198            unsuspend_fd,
199            unsuspend_requests: Default::default(),
200            has_unsuspend_requests: Default::default(),
201            unsuspend_triggered: Default::default(),
202        });
203        if let Some(server) = &state.server {
204            state.change_interest(server, |i| i | poll::READABLE);
205            state
206                .poller
207                .register(server.id, server.socket.as_fd())
208                .map_err(StateErrorKind::PollError)?;
209            let display = WlDisplay::new(&state, 1);
210            display
211                .core()
212                .set_server_id_unchecked(1, display.clone())
213                .unwrap();
214        }
215        state
216            .poller
217            .register_edge_triggered(
218                StaticPollableIds::Unsuspend as u64,
219                state.unsuspend_fd.as_fd(),
220                poll::READABLE,
221            )
222            .map_err(StateErrorKind::PollError)?;
223        Ok(state)
224    }
225
226    /// Constructs a state without a server.
227    pub fn without_server(mut self) -> Self {
228        self.server = Some(Server::None);
229        self
230    }
231
232    /// Sets the server file descriptor to connect to.
233    pub fn with_server_fd(mut self, fd: &Rc<OwnedFd>) -> Self {
234        self.server = Some(Server::Fd(fd.clone()));
235        self
236    }
237
238    /// Sets the server display name to connect to.
239    pub fn with_server_display_name(mut self, name: &str) -> Self {
240        self.server = Some(Server::DisplayName(name.to_owned()));
241        self
242    }
243
244    /// Enables or disables logging.
245    ///
246    /// If this function is not used, then logging is enabled if and only if the
247    /// `WL_PROXY_DEBUG` environment variable is set to `1`.
248    pub fn with_logging(mut self, log: bool) -> Self {
249        self.log = log;
250        self
251    }
252
253    /// Sets a log prefix for messages emitted by this state.
254    pub fn with_log_prefix(mut self, prefix: &str) -> Self {
255        self.log_prefix = prefix.to_string();
256        self
257    }
258}