async_niri_socket/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2//! Non-blocking communication over the niri socket.
3
4use futures_lite::{Stream, stream};
5use niri_ipc::{Event, Reply, Request, socket::SOCKET_PATH_ENV};
6use std::{env, io, path::Path};
7
8use error::IntoEventStreamError;
9mod error;
10
#[cfg(feature = "async-net")]
mod async_net_executor;

/// A [`Socket`] built on the `async-net` based stream implementation.
///
/// Only available with the `async-net` cargo feature.
#[cfg(feature = "async-net")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-net")))]
pub type AsyncNetSocket = Socket<async_net_executor::AsyncNetStream>;

#[cfg(feature = "tokio")]
mod tokio_executor;

/// A [`Socket`] built on the `tokio` based stream implementation.
///
/// Only available with the `tokio` cargo feature.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub type TokioSocket = Socket<tokio_executor::TokioStream>;
24
/// An asynchronous connection to niri's IPC socket.
///
/// `S` is the executor-specific stream the connection is built on. It is not
/// meant to be named directly. Use one of the feature-gated aliases
/// (`TokioSocket` or `AsyncNetSocket`) instead.
pub struct Socket<S> {
    /// The underlying connection. It is used both for request/reply exchanges
    /// and for reading the event stream.
    stream: S,
}
28
/// Operations that [`Socket`] needs from its executor-specific connection type.
///
/// The stream types paired with [`Socket`] in the feature-gated aliases are
/// expected to implement this trait.
trait SocketStream: Sized {
    /// Opens a connection to the socket at `path`.
    async fn connect_to(path: impl AsRef<Path>) -> io::Result<Self>;

    /// Writes all of `data` to the connection.
    async fn write_all(&mut self, data: &[u8]) -> io::Result<()>;

    /// Reads one line from the connection into `buf`.
    ///
    /// No byte count is returned. Callers that need to notice end-of-stream
    /// must inspect `buf` themselves.
    async fn read_line(&mut self, buf: &mut String) -> io::Result<()>;

    /// Shuts down the write half of the connection. Failures are not reported.
    async fn shutdown_write(&mut self);
}
35
36#[expect(private_bounds)]
37impl<S: SocketStream> Socket<S> {
38    fn from_stream(stream: S) -> Self {
39        Self { stream }
40    }
41
42    /// Connects to the default niri IPC socket.
43    ///
44    /// This is the async version of [Socket::connect](niri_ipc::socket::Socket::connect)
45    pub async fn connect() -> Result<Self, io::Error> {
46        let socket_path = env::var_os(SOCKET_PATH_ENV).ok_or_else(|| {
47            io::Error::new(
48                io::ErrorKind::NotFound,
49                format!("{SOCKET_PATH_ENV} is not set, are you running this within niri?"),
50            )
51        })?;
52        Self::connect_to(socket_path).await
53    }
54    /// Connects to the niri IPC socket at the given path.
55    ///
56    /// This is the async version of [Socket::connect_to](niri_ipc::socket::Socket::connect_to)
57    pub async fn connect_to(path: impl AsRef<Path>) -> Result<Self, io::Error> {
58        S::connect_to(path).await.map(Self::from_stream)
59    }
60
61    /// Sends a request to niri and returns the response.
62    ///
63    /// This is the async version of [Socket::send](niri_ipc::socket::Socket::send)
64    pub async fn send(&mut self, request: Request) -> Result<Reply, io::Error> {
65        let mut request = serde_json::to_string(&request).unwrap();
66        request.push('\n');
67
68        self.stream.write_all(request.as_bytes()).await?;
69
70        request.clear();
71        self.stream.read_line(&mut request).await?;
72
73        serde_json::from_str(&request).map_err(From::from)
74    }
75
76    /// Send request and reading event stream [`Event`]s from the socket.
77    ///
78    /// Note that unlike the [`Socket::read_events`](niri_ipc::socket::Socket::read_events),
79    /// this method will send an [`EventStream`][Request::EventStream] request.
80    ///
81    /// # Examples
82    ///
83    /// ```no_run
84    /// use niri_ipc::{Request, Response};
85    /// use niri_ipc::socket::Socket;
86    ///
87    /// async fn print_events() -> Result<(), std::io::Error> {
88    ///     let mut socket = Socket::connect().await?;
89    ///
90    ///     let reply = socket.into_event_stream().await;
91    ///     if let Ok(event_stream) = reply {
92    ///         let read_event = std::pin::pin!(event_stream);
93    ///         while let Some(event) = read_event.next().await {
94    ///             println!("Received event: {event:?}");
95    ///         }
96    ///     }
97    ///
98    ///     Ok(())
99    /// }
100    /// ```
101    pub async fn into_event_stream(
102        mut self,
103    ) -> Result<impl Stream<Item = Result<Event, io::Error>>, IntoEventStreamError> {
104        self.send(Request::EventStream)
105            .await?
106            .map_err(IntoEventStreamError::NiriNotHandled)?;
107        let mut stream = self.stream;
108        stream.shutdown_write().await;
109        Ok(Self::get_event_stream(stream))
110    }
111
112    fn get_event_stream(stream: S) -> impl Stream<Item = Result<Event, io::Error>> {
113        stream::unfold(
114            (stream, String::new()),
115            |(mut stream, mut buf)| async move {
116                buf.clear();
117
118                let event = stream
119                    .read_line(&mut buf)
120                    .await
121                    .and_then(|_| serde_json::from_str(&buf).map_err(From::from));
122                Some((event, (stream, buf)))
123            },
124        )
125    }
126}