1use std::sync::Arc;
2
3use serde::{Serialize, de::DeserializeOwned};
4use serde_json::Value;
5#[cfg(unix)]
6use tokio::net::UnixStream;
7#[cfg(windows)]
8use tokio::net::windows::named_pipe::ClientOptions;
9use tokio::{
10 io::{AsyncRead, AsyncWrite},
11 net::TcpStream,
12 sync::{mpsc, oneshot},
13};
14use uuid::Uuid;
15
16use crate::{
17 broker::{read_packet, write_packet},
18 rpc::{CallId, RpcRequest, RpcResponse},
19};
20
21pub trait AsyncStream: AsyncRead + AsyncWrite {}
24impl<T: AsyncRead + AsyncWrite + Unpin> AsyncStream for T {}
25
26enum ClientMsg {
31 Request {
33 req: RpcRequest,
34 resp_tx: oneshot::Sender<std::io::Result<RpcResponse>>,
35 },
36 Subscribe {
38 object_name: String,
39 topic: String,
40 updates: mpsc::UnboundedSender<serde_json::Value>,
41 },
42}
43
44#[derive(Clone)]
51pub struct IPCClient {
52 tx: mpsc::UnboundedSender<ClientMsg>,
53}
54
55impl IPCClient {
56 pub async fn connect() -> std::io::Result<Self> {
65 let stream: Box<dyn AsyncStream + Send + Unpin> =
67 if let Ok(ip) = std::env::var("BROKER_ADDR") {
68 let tcp = TcpStream::connect(ip.as_str()).await?;
69 log::info!("Connected into TCP: {ip}");
70 Box::new(tcp)
71 } else {
72 #[cfg(unix)]
74 {
75 use crate::rpc::UNIX_PATH;
76
77 let unix = UnixStream::connect(UNIX_PATH).await?;
78 log::info!("Connected into Unix: {UNIX_PATH}");
79 Box::new(unix)
80 }
81
82 #[cfg(windows)]
83 {
84 use crate::rpc::PIPE_PATH;
85 loop {
86 let res = match ClientOptions::new().open(PIPE_PATH) {
87 Ok(pipe) => {
88 log::info!("Connected into NamedPipe: {PIPE_PATH}");
89 Box::new(pipe)
90 }
91 Err(e) if e.raw_os_error() == Some(231) => {
92 use std::time::Duration;
95
96 log::error!("All pipe instances busy, retrying...");
97 tokio::time::sleep(Duration::from_millis(100)).await;
98 continue;
99 }
100 Err(e) => {
101 use std::time::Duration;
102 log::error!("Failed to connect to pipe: {}", e);
103 tokio::time::sleep(Duration::from_millis(100)).await;
104 continue;
105 }
106 };
107 break res;
108 }
109 }
110 };
111
112 let (tx, mut rx) = mpsc::unbounded_channel::<ClientMsg>();
114
115 tokio::spawn(async move {
117 let mut stream = stream;
118 let mut subs: std::collections::HashMap<
119 (String, String),
120 Vec<mpsc::UnboundedSender<serde_json::Value>>,
121 > = std::collections::HashMap::new();
122
123 loop {
124 tokio::select! {
125 Some(msg) = rx.recv() => {
126 match msg {
127 ClientMsg::Request { req, resp_tx } => {
128 let data = match serde_json::to_vec(&req) {
130 Ok(d) => d,
131 Err(e) => {
132 let _ = resp_tx.send(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)));
133 continue;
134 }
135 };
136
137 if let Err(e) = write_packet(&mut stream, &data).await {
138 let _ = resp_tx.send(Err(e));
139 continue;
140 }
141
142 match req {
144 RpcRequest::Call { .. } | RpcRequest::RegisterObject { .. } | RpcRequest::HasObject { .. } => {
145 match read_packet(&mut stream).await {
147 Ok(data) => {
148 let resp: Result<RpcResponse, _> = serde_json::from_slice(&data);
149 match resp {
150 Ok(r) => { let _ = resp_tx.send(Ok(r)); }
151 Err(e) => {
152 let _ = resp_tx.send(Err(std::io::Error::new(
153 std::io::ErrorKind::InvalidData,
154 e,
155 )));
156 }
157 }
158 }
159 Err(e) => {
160 let _ = resp_tx.send(Err(e));
161 }
162 }
163 }
164 RpcRequest::Publish { .. } | RpcRequest::Subscribe { .. } => {
165 log::trace!("Fire-and-forget: do not await a response");
166 let _ = resp_tx.send(Ok(RpcResponse::Event {
167 object_name: "".into(),
168 topic: "".into(),
169 args: serde_json::Value::Null,
170 }));
171 }
172 }
173 }
174 ClientMsg::Subscribe { object_name, topic, updates } => {
175 log::debug!("Client subscribing to {object_name}/{topic}");
176 subs.entry((object_name.clone(), topic.clone()))
177 .or_default()
178 .push(updates);
179 let data = serde_json::to_vec(&RpcRequest::Subscribe { object_name, topic }).unwrap();
180
181 let _ = write_packet(&mut stream, &data).await;
182 }
183
184 }
185 }
186 Ok(data) = read_packet(&mut stream) => {
187 if data.is_empty() {
188 break;
189 }
190 log::debug!("Data {}", String::from_utf8_lossy(&data));
191 match serde_json::from_slice::<RpcResponse>(&data) {
192 Ok(resp) => {
193 if let RpcResponse::Event { object_name, topic, args } = resp {
195 if let Some(subscribers) = subs.get(&(object_name.clone(), topic.clone())) {
196 for tx in subscribers {
197 let _ = tx.send(args.clone());
198 }
199 }
200 } else{
201 log::trace!("Other responses are ignored here; handled elsewhere");
202 }
203 continue;
204 }
205 Err(_) => {
206 log::trace!("Partial JSON, fallthrough to buffer handling");
207 }
208 }
209 }
210 }
211 }
212 });
213
214 Ok(Self { tx })
215 }
216
217 pub async fn remote_call<U, T>(&self, object: &str, method: &str, args: U) -> std::io::Result<T>
226 where
227 T: DeserializeOwned,
228 U: Serialize + std::any::Any + 'static,
229 {
230 let args = if let Some(val) = (&args as &dyn std::any::Any).downcast_ref::<Value>() {
231 val.clone()
232 } else {
233 serde_json::to_value(args)?
234 };
235 let call_id = CallId::from(Uuid::new_v4());
236
237 let req = RpcRequest::Call {
238 call_id,
239 object_name: object.into(),
240 method: method.into(),
241 args,
242 };
243
244 let (resp_tx, resp_rx) = oneshot::channel();
245 let msg = ClientMsg::Request { req, resp_tx };
246
247 self.tx
249 .send(msg)
250 .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Actor dropped"))?;
251
252 let resp = resp_rx.await.unwrap_or_else(|_| {
254 Err(std::io::Error::new(
255 std::io::ErrorKind::ConnectionAborted,
256 "Actor task ended",
257 ))
258 })?;
259
260 match resp {
262 RpcResponse::Result { value, .. } => serde_json::from_value(value).map_err(|e| {
263 std::io::Error::new(
264 std::io::ErrorKind::InvalidData,
265 format!("Deserialize error: {e}"),
266 )
267 }),
268 RpcResponse::Error { message, .. } => {
269 Err(std::io::Error::other(format!("Remote error: {message}")))
270 }
271 _ => Err(std::io::Error::new(
272 std::io::ErrorKind::InvalidData,
273 "Unexpected response type",
274 )),
275 }
276 }
277
278 pub async fn publish(
282 &self,
283 object: &str,
284 topic: &str,
285 args: &serde_json::Value,
286 ) -> std::io::Result<()> {
287 let (resp_tx, _resp_rx) = oneshot::channel();
288 let msg = ClientMsg::Request {
289 req: RpcRequest::Publish {
290 object_name: object.into(),
291 topic: topic.into(),
292 args: args.clone(),
293 },
294 resp_tx,
295 };
296
297 self.tx
298 .send(msg)
299 .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Actor dropped"))
300 }
301
302 pub async fn subscribe(
307 &self,
308 object: &str,
309 topic: &str,
310 ) -> mpsc::UnboundedReceiver<serde_json::Value> {
311 let (tx_updates, rx_updates) = mpsc::unbounded_channel();
312 let _ = self.tx.send(ClientMsg::Subscribe {
313 object_name: object.into(),
314 topic: topic.into(),
315 updates: tx_updates,
316 });
317 rx_updates
318 }
319
320 pub async fn subscribe_async<F>(&self, object: &str, topic: &str, callback: F)
325 where
326 F: Fn(Value) + Send + Sync + 'static,
327 {
328 let (tx, mut rx) = mpsc::unbounded_channel::<Value>();
329 let callback = Arc::new(callback);
330
331 let _ = self.tx.send(ClientMsg::Subscribe {
333 object_name: object.into(),
334 topic: topic.into(),
335 updates: tx,
336 });
337
338 tokio::spawn(async move {
340 while let Some(msg) = rx.recv().await {
341 let cb = callback.clone();
342 cb(msg);
344 }
345 });
346 }
347
348 pub async fn wait_for_object(&self, object: &str) -> std::io::Result<()> {
352 loop {
353 if self.has_object(object).await? {
354 return Ok(());
355 }
356 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
357 }
358 }
359
360 async fn has_object(&self, object: &str) -> std::io::Result<bool> {
364 let req = RpcRequest::HasObject {
365 object_name: object.into(),
366 };
367
368 let (resp_tx, resp_rx) = oneshot::channel();
369 let msg = ClientMsg::Request { req, resp_tx };
370
371 self.tx
373 .send(msg)
374 .map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Actor dropped"))?;
375
376 match resp_rx.await.unwrap_or_else(|_| {
378 Err(std::io::Error::new(
379 std::io::ErrorKind::ConnectionAborted,
380 "Actor task ended",
381 ))
382 })? {
383 RpcResponse::HasObjectResult { exists, .. } => Ok(exists),
384 _ => Err(std::io::Error::new(
385 std::io::ErrorKind::InvalidData,
386 "Unexpected response type",
387 )),
388 }
389 }
390}