async_niri_socket/
lib.rs

1//! Non-blocking communication over the niri socket.
2
3use futures_lite::{Stream, stream};
4use niri_ipc::{Event, Reply, Request, socket::SOCKET_PATH_ENV};
5use std::{env, io, path::Path};
6
7use error::IntoEventStreamError;
8mod error;
9
#[cfg(feature = "async-net")]
mod async_net_executor;

#[cfg(feature = "tokio")]
mod tokio_executor;

/// A [`Socket`] whose I/O runs on the `async-net` backend.
///
/// Only available with the `async-net` feature.
#[cfg(feature = "async-net")]
pub type AsyncNetSocket = Socket<async_net_executor::AsyncNetStream>;

/// A [`Socket`] whose I/O runs on the `tokio` backend.
///
/// Only available with the `tokio` feature.
#[cfg(feature = "tokio")]
pub type TokioSocket = Socket<tokio_executor::TokioStream>;
21
/// An asynchronous connection to the niri IPC socket, generic over the I/O backend `S`.
///
/// Create one with `Socket::connect` (default socket from the environment) or
/// `Socket::connect_to` (explicit path). When the `async-net` or `tokio` feature is enabled,
/// the `AsyncNetSocket` / `TokioSocket` aliases name the concrete backend types.
pub struct Socket<S> {
    /// Backend stream that every request write and reply/event read goes through.
    stream: S,
}
25
/// The async I/O operations a runtime backend has to provide for [`Socket`].
///
/// Presumably each runtime feature (`async-net`, `tokio`) implements this in its own
/// submodule.
trait SocketStream: Sized {
    /// Opens a connection to the socket located at `path`.
    async fn connect_to(path: impl AsRef<Path>) -> io::Result<Self>;

    /// Reads one line from the socket into `buf`.
    ///
    /// NOTE(review): callers clear `buf` before each call and then parse its whole contents as
    /// one JSON message. Implementations are therefore presumably expected to append a single
    /// newline-terminated line, as `BufRead::read_line` does. Confirm against the backends.
    async fn read_line(&mut self, buf: &mut String) -> io::Result<()>;

    /// Writes all of `data` to the socket.
    async fn write_all(&mut self, data: &[u8]) -> io::Result<()>;

    /// Shuts down the write half of the connection.
    ///
    /// This returns nothing, so any shutdown failure is handled, or ignored, by the
    /// implementation itself.
    async fn shutdown_write(&mut self);
}
32
33#[expect(private_bounds)]
34impl<S: SocketStream> Socket<S> {
35    fn from_stream(stream: S) -> Self {
36        Self { stream }
37    }
38
39    /// Connects to the default niri IPC socket.
40    ///
41    /// This is the async version of [Socket::connect](niri_ipc::socket::Socket::connect)
42    pub async fn connect() -> Result<Self, io::Error> {
43        let socket_path = env::var_os(SOCKET_PATH_ENV).ok_or_else(|| {
44            io::Error::new(
45                io::ErrorKind::NotFound,
46                format!("{SOCKET_PATH_ENV} is not set, are you running this within niri?"),
47            )
48        })?;
49        Self::connect_to(socket_path).await
50    }
51    /// Connects to the niri IPC socket at the given path.
52    ///
53    /// This is the async version of [Socket::connect_to](niri_ipc::socket::Socket::connect_to)
54    pub async fn connect_to(path: impl AsRef<Path>) -> Result<Self, io::Error> {
55        S::connect_to(path).await.map(Self::from_stream)
56    }
57
58    /// Sends a request to niri and returns the response.
59    ///
60    /// This is the async version of [Socket::send](niri_ipc::socket::Socket::send)
61    pub async fn send(&mut self, request: Request) -> Result<Reply, io::Error> {
62        let mut request = serde_json::to_string(&request).unwrap();
63        request.push('\n');
64
65        self.stream.write_all(request.as_bytes()).await?;
66
67        request.clear();
68        self.stream.read_line(&mut request).await?;
69
70        serde_json::from_str(&request).map_err(From::from)
71    }
72
73    /// Send request and reading event stream [`Event`]s from the socket.
74    ///
75    /// Note that unlike the [`Socket::read_events`](niri_ipc::socket::Socket::read_events),
76    /// this method will send an [`EventStream`][Request::EventStream] request.
77    ///
78    /// # Examples
79    ///
80    /// ```no_run
81    /// use niri_ipc::{Request, Response};
82    /// use niri_ipc::socket::Socket;
83    ///
84    /// async fn print_events() -> Result<(), std::io::Error> {
85    ///     let mut socket = Socket::connect().await?;
86    ///
87    ///     let reply = socket.into_event_stream().await;
88    ///     if let Ok(event_stream) = reply {
89    ///         let read_event = std::pin::pin!(event_stream);
90    ///         while let Some(event) = read_event.next().await {
91    ///             println!("Received event: {event:?}");
92    ///         }
93    ///     }
94    ///
95    ///     Ok(())
96    /// }
97    /// ```
98    pub async fn into_event_stream(
99        mut self,
100    ) -> Result<impl Stream<Item = Result<Event, io::Error>>, IntoEventStreamError> {
101        self.send(Request::EventStream)
102            .await?
103            .map_err(IntoEventStreamError::NiriNotHandled)?;
104        let mut stream = self.stream;
105        stream.shutdown_write().await;
106        Ok(Self::get_event_stream(stream))
107    }
108
109    fn get_event_stream(stream: S) -> impl Stream<Item = Result<Event, io::Error>> {
110        stream::unfold(
111            (stream, String::new()),
112            |(mut stream, mut buf)| async move {
113                buf.clear();
114
115                let event = stream
116                    .read_line(&mut buf)
117                    .await
118                    .and_then(|_| serde_json::from_str(&buf).map_err(From::from));
119                Some((event, (stream, buf)))
120            },
121        )
122    }
123}