1use futures_lite::{Stream, stream};
4use niri_ipc::{Event, Reply, Request, socket::SOCKET_PATH_ENV};
5use std::{env, io, path::Path};
6
7use error::IntoEventStreamError;
8mod error;
9
// Runtime backends. Each is compiled only when its cargo feature is enabled,
// and provides the stream type used by the matching alias below.
#[cfg(feature = "async-net")]
mod async_net_executor;
#[cfg(feature = "tokio")]
mod tokio_executor;

/// A [`Socket`] driven by the `async-net` backend (requires the `async-net` feature).
#[cfg(feature = "async-net")]
pub type AsyncNetSocket = Socket<async_net_executor::AsyncNetStream>;

/// A [`Socket`] driven by the `tokio` backend (requires the `tokio` feature).
#[cfg(feature = "tokio")]
pub type TokioSocket = Socket<tokio_executor::TokioStream>;
21
/// A connection to niri's IPC socket.
///
/// `S` is the runtime-specific stream type that carries the bytes.
pub struct Socket<S> {
    /// The connected transport; all reads and writes go through it.
    stream: S,
}
25
/// The minimal async I/O surface that a runtime backend must provide so a
/// `Socket` can talk to niri over it.
trait SocketStream: Sized {
    /// Connects to the socket located at `path`.
    async fn connect_to(path: impl AsRef<Path>) -> io::Result<Self>;

    /// Reads the next line into `buf`.
    ///
    /// Callers clear `buf` before calling. Presumably implementations append,
    /// like `BufRead::read_line`; confirm against the backends.
    async fn read_line(&mut self, buf: &mut String) -> io::Result<()>;

    /// Writes the whole of `data` to the stream.
    async fn write_all(&mut self, data: &[u8]) -> io::Result<()>;

    /// Shuts down the writing side of the stream. No error is surfaced to the
    /// caller, because the method returns `()`.
    async fn shutdown_write(&mut self);
}
32
33#[expect(private_bounds)]
34impl<S: SocketStream> Socket<S> {
35 fn from_stream(stream: S) -> Self {
36 Self { stream }
37 }
38
39 pub async fn connect() -> Result<Self, io::Error> {
43 let socket_path = env::var_os(SOCKET_PATH_ENV).ok_or_else(|| {
44 io::Error::new(
45 io::ErrorKind::NotFound,
46 format!("{SOCKET_PATH_ENV} is not set, are you running this within niri?"),
47 )
48 })?;
49 Self::connect_to(socket_path).await
50 }
51 pub async fn connect_to(path: impl AsRef<Path>) -> Result<Self, io::Error> {
55 S::connect_to(path).await.map(Self::from_stream)
56 }
57
58 pub async fn send(&mut self, request: Request) -> Result<Reply, io::Error> {
62 let mut request = serde_json::to_string(&request).unwrap();
63 request.push('\n');
64
65 self.stream.write_all(request.as_bytes()).await?;
66
67 request.clear();
68 self.stream.read_line(&mut request).await?;
69
70 serde_json::from_str(&request).map_err(From::from)
71 }
72
73 pub async fn into_event_stream(
99 mut self,
100 ) -> Result<impl Stream<Item = Result<Event, io::Error>>, IntoEventStreamError> {
101 self.send(Request::EventStream)
102 .await?
103 .map_err(IntoEventStreamError::NiriNotHandled)?;
104 let mut stream = self.stream;
105 stream.shutdown_write().await;
106 Ok(Self::get_event_stream(stream))
107 }
108
109 fn get_event_stream(stream: S) -> impl Stream<Item = Result<Event, io::Error>> {
110 stream::unfold(
111 (stream, String::new()),
112 |(mut stream, mut buf)| async move {
113 buf.clear();
114
115 let event = stream
116 .read_line(&mut buf)
117 .await
118 .and_then(|_| serde_json::from_str(&buf).map_err(From::from));
119 Some((event, (stream, buf)))
120 },
121 )
122 }
123}