async_niri_socket/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2//! Non-blocking communication over the niri socket.
3
4use futures_lite::{Stream, stream};
5use niri_ipc::{Event, Reply, Request, Response, socket::SOCKET_PATH_ENV};
6use std::{env, io, path::Path};
7
8pub use error::NiriReplyError;
9mod error;
10
// Each async runtime backend lives in its own feature-gated module and is exposed
// through a public `Socket` alias, so users never have to name the stream type.

#[cfg(feature = "tokio")]
mod tokio_executor;

#[cfg(feature = "async-net")]
mod async_net_executor;

/// [`Socket`] using the tokio backend (enabled by the `tokio` feature).
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub type TokioSocket = Socket<tokio_executor::TokioStream>;

/// [`Socket`] using the async-net backend (enabled by the `async-net` feature).
#[cfg(feature = "async-net")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-net")))]
pub type AsyncNetSocket = Socket<async_net_executor::AsyncNetStream>;
24
/// An asynchronous connection to the niri IPC socket.
///
/// `S` is the backend-specific stream type. Rather than naming it directly, use one of the
/// feature-gated aliases (`TokioSocket` with the `tokio` feature, `AsyncNetSocket` with the
/// `async-net` feature).
#[derive(Debug)]
pub struct Socket<S> {
    stream: S,
}
28
/// Backend-specific stream operations that [`Socket`] is built on.
///
/// Each async runtime backend supplies its own stream type implementing this trait.
trait SocketStream: Sized {
    /// Opens a connection to the socket located at `path`.
    async fn connect_to(path: impl AsRef<Path>) -> Result<Self, io::Error>;

    /// Writes every byte of `bytes` to the peer.
    async fn write_all(&mut self, bytes: &[u8]) -> Result<(), io::Error>;

    /// Reads the next newline-terminated message into `line`.
    ///
    /// Callers clear `line` before calling. NOTE(review): presumably this appends like
    /// `BufRead::read_line` and reads nothing at EOF; confirm against each backend.
    async fn read_line(&mut self, line: &mut String) -> Result<(), io::Error>;

    /// Shuts down the writing half of the connection; any failure is not reported.
    async fn shutdown_write(&mut self);
}
35
36#[expect(private_bounds)]
37impl<S: SocketStream> Socket<S> {
38    fn from_stream(stream: S) -> Self {
39        Self { stream }
40    }
41
42    /// Connects to the default niri IPC socket.
43    ///
44    /// This is the async version of [Socket::connect](niri_ipc::socket::Socket::connect)
45    pub async fn connect() -> Result<Self, io::Error> {
46        let socket_path = env::var_os(SOCKET_PATH_ENV).ok_or_else(|| {
47            io::Error::new(
48                io::ErrorKind::NotFound,
49                format!("{SOCKET_PATH_ENV} is not set, are you running this within niri?"),
50            )
51        })?;
52        Self::connect_to(socket_path).await
53    }
54    /// Connects to the niri IPC socket at the given path.
55    ///
56    /// This is the async version of [Socket::connect_to](niri_ipc::socket::Socket::connect_to)
57    pub async fn connect_to(path: impl AsRef<Path>) -> Result<Self, io::Error> {
58        S::connect_to(path).await.map(Self::from_stream)
59    }
60
61    /// Sends a request to niri and returns the response.
62    ///
63    /// This is the async version of [Socket::send](niri_ipc::socket::Socket::send), but with a
64    /// flatten error type [NiriReplyError].
65    pub async fn send(&mut self, request: Request) -> Result<Response, NiriReplyError> {
66        let mut buf = serde_json::to_string(&request).unwrap();
67        buf.push('\n');
68
69        self.stream.write_all(buf.as_bytes()).await?;
70
71        buf.clear();
72        self.stream.read_line(&mut buf).await?;
73
74        serde_json::from_str::<'_, Reply>(&buf)
75            .map_err(io::Error::from)?
76            .map_err(NiriReplyError::Niri)
77    }
78
79    /// Send request and reading event stream [`Event`]s from the socket.
80    ///
81    /// Note that unlike the [`Socket::read_events`](niri_ipc::socket::Socket::read_events),
82    /// this method will send an [`EventStream`][Request::EventStream] request.
83    ///
84    /// # Examples
85    ///
86    /// ```no_run
87    /// use niri_ipc::{Request, Response};
88    /// use niri_ipc::socket::Socket;
89    ///
90    /// async fn print_events() -> Result<(), std::io::Error> {
91    ///     let mut socket = Socket::connect().await?;
92    ///
93    ///     let reply = socket.into_event_stream().await;
94    ///     if let Ok(event_stream) = reply {
95    ///         let read_event = std::pin::pin!(event_stream);
96    ///         while let Some(event) = read_event.next().await {
97    ///             println!("Received event: {event:?}");
98    ///         }
99    ///     }
100    ///
101    ///     Ok(())
102    /// }
103    /// ```
104    pub async fn into_event_stream(
105        mut self,
106    ) -> Result<impl Stream<Item = Result<Event, io::Error>>, NiriReplyError> {
107        self.send(Request::EventStream).await?;
108        let mut stream = self.stream;
109        stream.shutdown_write().await;
110        Ok(Self::get_event_stream(stream))
111    }
112
113    fn get_event_stream(stream: S) -> impl Stream<Item = Result<Event, io::Error>> {
114        stream::unfold(
115            (stream, String::new()),
116            |(mut stream, mut buf)| async move {
117                buf.clear();
118
119                let event = stream
120                    .read_line(&mut buf)
121                    .await
122                    .and_then(|_| serde_json::from_str(&buf).map_err(From::from));
123                Some((event, (stream, buf)))
124            },
125        )
126    }
127}