1#![cfg_attr(docsrs, feature(doc_cfg))]
2use futures_lite::{Stream, stream};
5use niri_ipc::{Event, Reply, Request, socket::SOCKET_PATH_ENV};
6use std::{env, io, path::Path};
7
8use error::IntoEventStreamError;
9mod error;
10
/// Backend glue compiled only when the `async-net` feature is enabled.
#[cfg(feature = "async-net")]
mod async_net_executor;

/// A [`Socket`] whose transport is the `async-net` backend stream.
///
/// Only available when the `async-net` feature is enabled.
#[cfg(feature = "async-net")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-net")))]
pub type AsyncNetSocket = Socket<self::async_net_executor::AsyncNetStream>;
17
/// Backend glue compiled only when the `tokio` feature is enabled.
#[cfg(feature = "tokio")]
mod tokio_executor;

/// A [`Socket`] whose transport is the `tokio` backend stream.
///
/// Only available when the `tokio` feature is enabled.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
pub type TokioSocket = Socket<self::tokio_executor::TokioStream>;
24
/// A connection to niri's IPC socket, generic over the backend stream type `S`.
///
/// Create one with `Socket::connect` or `Socket::connect_to`. The concrete backends are
/// exposed through the feature-gated aliases `TokioSocket` and `AsyncNetSocket`.
pub struct Socket<S> {
    /// Backend stream that carries the newline-delimited JSON requests and replies.
    stream: S,
}
28
/// Transport operations a backend stream must provide so that [`Socket`] can talk to niri.
///
/// The backend stream types used by the feature-gated aliases are presumably the
/// implementors of this trait; verify against the executor modules.
trait SocketStream: Sized {
    /// Opens a connection to the Unix socket at `socket_path`.
    async fn connect_to(socket_path: impl AsRef<Path>) -> Result<Self, io::Error>;

    /// Writes every byte of `bytes` to the connection.
    async fn write_all(&mut self, bytes: &[u8]) -> Result<(), io::Error>;

    /// Reads one line from the connection into `line`.
    ///
    /// Callers in this crate clear `line` before each call.
    async fn read_line(&mut self, line: &mut String) -> Result<(), io::Error>;

    /// Shuts down the writing half of the connection.
    ///
    /// The signature is infallible, so implementations must swallow any error themselves.
    async fn shutdown_write(&mut self);
}
35
36#[expect(private_bounds)]
37impl<S: SocketStream> Socket<S> {
38 fn from_stream(stream: S) -> Self {
39 Self { stream }
40 }
41
42 pub async fn connect() -> Result<Self, io::Error> {
46 let socket_path = env::var_os(SOCKET_PATH_ENV).ok_or_else(|| {
47 io::Error::new(
48 io::ErrorKind::NotFound,
49 format!("{SOCKET_PATH_ENV} is not set, are you running this within niri?"),
50 )
51 })?;
52 Self::connect_to(socket_path).await
53 }
54 pub async fn connect_to(path: impl AsRef<Path>) -> Result<Self, io::Error> {
58 S::connect_to(path).await.map(Self::from_stream)
59 }
60
61 pub async fn send(&mut self, request: Request) -> Result<Reply, io::Error> {
65 let mut request = serde_json::to_string(&request).unwrap();
66 request.push('\n');
67
68 self.stream.write_all(request.as_bytes()).await?;
69
70 request.clear();
71 self.stream.read_line(&mut request).await?;
72
73 serde_json::from_str(&request).map_err(From::from)
74 }
75
76 pub async fn into_event_stream(
102 mut self,
103 ) -> Result<impl Stream<Item = Result<Event, io::Error>>, IntoEventStreamError> {
104 self.send(Request::EventStream)
105 .await?
106 .map_err(IntoEventStreamError::NiriNotHandled)?;
107 let mut stream = self.stream;
108 stream.shutdown_write().await;
109 Ok(Self::get_event_stream(stream))
110 }
111
112 fn get_event_stream(stream: S) -> impl Stream<Item = Result<Event, io::Error>> {
113 stream::unfold(
114 (stream, String::new()),
115 |(mut stream, mut buf)| async move {
116 buf.clear();
117
118 let event = stream
119 .read_line(&mut buf)
120 .await
121 .and_then(|_| serde_json::from_str(&buf).map_err(From::from));
122 Some((event, (stream, buf)))
123 },
124 )
125 }
126}