1#![cfg_attr(docsrs, feature(doc_cfg))]
2use futures_lite::{Stream, stream};
5use niri_ipc::{Event, Reply, Request, Response, socket::SOCKET_PATH_ENV};
6use std::{env, io, path::Path};
7
8pub use error::NiriReplyError;
9mod error;
10
#[cfg(feature = "async-net")]
mod async_net_executor;

/// [`Socket`] whose transport is the stream type from the `async_net_executor` module.
///
/// Only available when the `async-net` feature is enabled.
#[cfg_attr(docsrs, doc(cfg(feature = "async-net")))]
#[cfg(feature = "async-net")]
pub type AsyncNetSocket = Socket<self::async_net_executor::AsyncNetStream>;
17
#[cfg(feature = "tokio")]
mod tokio_executor;

/// [`Socket`] whose transport is the stream type from the `tokio_executor` module.
///
/// Only available when the `tokio` feature is enabled.
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
pub type TokioSocket = Socket<self::tokio_executor::TokioStream>;
24
/// An asynchronous connection to niri's IPC socket.
///
/// `S` is the underlying transport. Use one of the feature-gated aliases, such as
/// `TokioSocket` (feature `tokio`) or `AsyncNetSocket` (feature `async-net`),
/// rather than naming `S` directly.
///
/// `Debug` is implemented whenever the transport `S` is itself `Debug`.
#[derive(Debug)]
pub struct Socket<S> {
    /// The connected transport that all requests and replies go through.
    stream: S,
}
28
/// Transport operations that a [`Socket`] needs from its underlying stream.
///
/// Presumably implemented by the stream types in the executor-specific modules
/// (one per supported async runtime).
trait SocketStream: Sized {
    /// Opens a connection to the socket at `path`.
    async fn connect_to(path: impl AsRef<Path>) -> Result<Self, io::Error>;

    /// Writes every byte of `bytes` to the connection.
    async fn write_all(&mut self, bytes: &[u8]) -> Result<(), io::Error>;

    /// Reads the next line from the connection into `line`.
    ///
    /// Callers clear `line` before each call and decode its contents as one JSON
    /// document. NOTE(review): presumably an end-of-file read returns `Ok(())` and
    /// leaves `line` empty. Confirm this for each backend.
    async fn read_line(&mut self, line: &mut String) -> Result<(), io::Error>;

    /// Shuts down the writing half of the connection. Failures are not reported to
    /// the caller.
    async fn shutdown_write(&mut self);
}
35
36#[expect(private_bounds)]
37impl<S: SocketStream> Socket<S> {
38 fn from_stream(stream: S) -> Self {
39 Self { stream }
40 }
41
42 pub async fn connect() -> Result<Self, io::Error> {
46 let socket_path = env::var_os(SOCKET_PATH_ENV).ok_or_else(|| {
47 io::Error::new(
48 io::ErrorKind::NotFound,
49 format!("{SOCKET_PATH_ENV} is not set, are you running this within niri?"),
50 )
51 })?;
52 Self::connect_to(socket_path).await
53 }
54 pub async fn connect_to(path: impl AsRef<Path>) -> Result<Self, io::Error> {
58 S::connect_to(path).await.map(Self::from_stream)
59 }
60
61 pub async fn send(&mut self, request: Request) -> Result<Response, NiriReplyError> {
66 let mut buf = serde_json::to_string(&request).unwrap();
67 buf.push('\n');
68
69 self.stream.write_all(buf.as_bytes()).await?;
70
71 buf.clear();
72 self.stream.read_line(&mut buf).await?;
73
74 serde_json::from_str::<'_, Reply>(&buf)
75 .map_err(io::Error::from)?
76 .map_err(NiriReplyError::Niri)
77 }
78
79 pub async fn into_event_stream(
105 mut self,
106 ) -> Result<impl Stream<Item = Result<Event, io::Error>>, NiriReplyError> {
107 self.send(Request::EventStream).await?;
108 let mut stream = self.stream;
109 stream.shutdown_write().await;
110 Ok(Self::get_event_stream(stream))
111 }
112
113 fn get_event_stream(stream: S) -> impl Stream<Item = Result<Event, io::Error>> {
114 stream::unfold(
115 (stream, String::new()),
116 |(mut stream, mut buf)| async move {
117 buf.clear();
118
119 let event = stream
120 .read_line(&mut buf)
121 .await
122 .and_then(|_| serde_json::from_str(&buf).map_err(From::from));
123 Some((event, (stream, buf)))
124 },
125 )
126 }
127}