1use crate::{api::SubscriptionId, helpers, BatchTransport, DuplexTransport, Error, RequestId, Result, Transport};
4use futures::future::{join_all, JoinAll};
5use jsonrpc_core as rpc;
6use std::{
7 collections::BTreeMap,
8 path::Path,
9 pin::Pin,
10 sync::{atomic::AtomicUsize, Arc},
11 task::{Context, Poll},
12};
13use tokio::{
14 io::{reader_stream, AsyncWriteExt},
15 net::UnixStream,
16 stream::StreamExt,
17 sync::{mpsc, oneshot},
18};
19
/// Cloneable handle to a JSON-RPC transport over a Unix domain socket (IPC).
///
/// All clones share the same request-id counter and the same command channel to the
/// single background task that owns the socket, so they all talk over one connection.
#[derive(Debug, Clone)]
pub struct Ipc {
    /// Next request id to hand out; shared between clones so ids stay unique per connection.
    id: Arc<AtomicUsize>,
    /// Command queue consumed by the background task that owns the socket.
    messages_tx: mpsc::UnboundedSender<TransportMessage>,
}
26
27#[cfg(unix)]
28impl Ipc {
29 pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
33 let stream = UnixStream::connect(path).await?;
34
35 Ok(Self::with_stream(stream))
36 }
37
38 fn with_stream(stream: UnixStream) -> Self {
39 let id = Arc::new(AtomicUsize::new(1));
40 let (messages_tx, messages_rx) = mpsc::unbounded_channel();
41
42 tokio::spawn(run_server(stream, messages_rx));
43
44 Ipc { id, messages_tx }
45 }
46}
47
48impl Transport for Ipc {
49 type Out = SingleResponse;
50
51 fn prepare(&self, method: &str, params: Vec<rpc::Value>) -> (crate::RequestId, rpc::Call) {
52 let id = self.id.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
53 let request = helpers::build_request(id, method, params);
54 (id, request)
55 }
56
57 fn send(&self, id: RequestId, call: rpc::Call) -> Self::Out {
58 let (response_tx, response_rx) = oneshot::channel();
59 let message = TransportMessage::Single((id, call, response_tx));
60
61 SingleResponse(self.messages_tx.send(message).map(|()| response_rx).map_err(Into::into))
62 }
63}
64
65impl BatchTransport for Ipc {
66 type Batch = BatchResponse;
67
68 fn send_batch<T: IntoIterator<Item = (RequestId, rpc::Call)>>(&self, requests: T) -> Self::Batch {
69 let mut response_rxs = vec![];
70
71 let message = TransportMessage::Batch(
72 requests
73 .into_iter()
74 .map(|(id, call)| {
75 let (response_tx, response_rx) = oneshot::channel();
76 response_rxs.push(response_rx);
77
78 (id, call, response_tx)
79 })
80 .collect(),
81 );
82
83 BatchResponse(
84 self.messages_tx
85 .send(message)
86 .map(|()| join_all(response_rxs))
87 .map_err(Into::into),
88 )
89 }
90}
91
92impl DuplexTransport for Ipc {
93 type NotificationStream = mpsc::UnboundedReceiver<rpc::Value>;
94
95 fn subscribe(&self, id: SubscriptionId) -> Result<Self::NotificationStream> {
96 let (tx, rx) = mpsc::unbounded_channel();
97 self.messages_tx.send(TransportMessage::Subscribe(id, tx))?;
98 Ok(rx)
99 }
100
101 fn unsubscribe(&self, id: SubscriptionId) -> Result<()> {
102 self.messages_tx
103 .send(TransportMessage::Unsubscribe(id))
104 .map_err(Into::into)
105 }
106}
107
108pub struct SingleResponse(Result<oneshot::Receiver<rpc::Value>>);
110
111impl futures::Future for SingleResponse {
112 type Output = Result<rpc::Value>;
113 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
114 match &mut self.0 {
115 Err(err) => Poll::Ready(Err(err.clone())),
116 Ok(ref mut rx) => {
117 let value = ready!(futures::Future::poll(Pin::new(rx), cx))?;
118 Poll::Ready(Ok(value))
119 }
120 }
121 }
122}
123
124pub struct BatchResponse(Result<JoinAll<oneshot::Receiver<rpc::Value>>>);
126
127impl futures::Future for BatchResponse {
128 type Output = Result<Vec<Result<rpc::Value>>>;
129 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130 match &mut self.0 {
131 Err(err) => Poll::Ready(Err(err.clone())),
132 Ok(ref mut rxs) => {
133 let poll = futures::Future::poll(Pin::new(rxs), cx);
134 let values = ready!(poll).into_iter().map(|r| r.map_err(Into::into)).collect();
135
136 Poll::Ready(Ok(values))
137 }
138 }
139 }
140}
141
/// One queued request: its id, the JSON-RPC call to write to the socket, and the channel
/// on which the response value is delivered.
type TransportRequest = (RequestId, rpc::Call, oneshot::Sender<rpc::Value>);

/// Commands sent from `Ipc` handles to the background `run_server` task.
#[derive(Debug)]
enum TransportMessage {
    /// One request, written as a single JSON-RPC call.
    Single(TransportRequest),
    /// Several requests, written together as one JSON-RPC batch.
    Batch(Vec<TransportRequest>),
    /// Route notifications for this subscription id to the given sender.
    Subscribe(SubscriptionId, mpsc::UnboundedSender<rpc::Value>),
    /// Stop routing notifications for this subscription id.
    Unsubscribe(SubscriptionId),
}
151
152#[cfg(unix)]
153async fn run_server(mut unix_stream: UnixStream, messages_rx: mpsc::UnboundedReceiver<TransportMessage>) -> Result<()> {
154 let (socket_reader, mut socket_writer) = unix_stream.split();
155 let mut pending_response_txs = BTreeMap::default();
156 let mut subscription_txs = BTreeMap::default();
157
158 let mut socket_reader = reader_stream(socket_reader);
159 let mut messages_rx = messages_rx.fuse();
160 let mut read_buffer = vec![];
161 let mut closed = false;
162
163 while !closed || pending_response_txs.len() > 0 {
164 tokio::select! {
165 message = messages_rx.next() => match message {
166 None => closed = true,
167 Some(TransportMessage::Subscribe(id, tx)) => {
168 if let Some(_) = subscription_txs.insert(id.clone(), tx) {
169 log::warn!("Replacing a subscription with id {:?}", id);
170 }
171 },
172 Some(TransportMessage::Unsubscribe(id)) => {
173 if let None = subscription_txs.remove(&id) {
174 log::warn!("Unsubscribing not subscribed id {:?}", id);
175 }
176 },
177 Some(TransportMessage::Single((request_id, rpc_call, response_tx))) => {
178 if pending_response_txs.insert(request_id, response_tx).is_some() {
179 log::warn!("Replacing a pending request with id {:?}", request_id);
180 }
181
182 let bytes = helpers::to_string(&rpc::Request::Single(rpc_call)).into_bytes();
183 if let Err(err) = socket_writer.write(&bytes).await {
184 pending_response_txs.remove(&request_id);
185 log::error!("IPC write error: {:?}", err);
186 }
187 }
188 Some(TransportMessage::Batch(requests)) => {
189 let mut request_ids = vec![];
190 let mut rpc_calls = vec![];
191
192 for (request_id, rpc_call, response_tx) in requests {
193 request_ids.push(request_id);
194 rpc_calls.push(rpc_call);
195
196 if pending_response_txs.insert(request_id, response_tx).is_some() {
197 log::warn!("Replacing a pending request with id {:?}", request_id);
198 }
199 }
200
201 let bytes = helpers::to_string(&rpc::Request::Batch(rpc_calls)).into_bytes();
202
203 if let Err(err) = socket_writer.write(&bytes).await {
204 log::error!("IPC write error: {:?}", err);
205 for request_id in request_ids {
206 pending_response_txs.remove(&request_id);
207 }
208 }
209 }
210 },
211 bytes = socket_reader.next() => match bytes {
212 Some(Ok(bytes)) => {
213 read_buffer.extend_from_slice(&bytes);
214
215 let read_len = {
216 let mut de: serde_json::StreamDeserializer<_, serde_json::Value> =
217 serde_json::Deserializer::from_slice(&read_buffer).into_iter();
218
219 while let Some(Ok(value)) = de.next() {
220 if let Ok(notification) = serde_json::from_value::<rpc::Notification>(value.clone()) {
221 let _ = notify(&mut subscription_txs, notification);
222 continue;
223 }
224
225 if let Ok(response) = serde_json::from_value::<rpc::Response>(value) {
226 let _ = respond(&mut pending_response_txs, response);
227 continue;
228 }
229
230 log::warn!("JSON is not a response or notification");
231 }
232
233 de.byte_offset()
234 };
235
236 read_buffer.copy_within(read_len.., 0);
237 read_buffer.truncate(read_buffer.len() - read_len);
238 },
239 Some(Err(err)) => {
240 log::error!("IPC read error: {:?}", err);
241 return Err(err.into());
242 },
243 None => break,
244 }
245 };
246 }
247
248 Ok(())
249}
250
251fn notify(
252 subscription_txs: &mut BTreeMap<SubscriptionId, mpsc::UnboundedSender<rpc::Value>>,
253 notification: rpc::Notification,
254) -> std::result::Result<(), ()> {
255 if let rpc::Params::Map(params) = notification.params {
256 let id = params.get("subscription");
257 let result = params.get("result");
258
259 if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) {
260 let id: SubscriptionId = id.clone().into();
261 if let Some(tx) = subscription_txs.get(&id) {
262 if let Err(e) = tx.send(result.clone()) {
263 log::error!("Error sending notification: {:?} (id: {:?}", e, id);
264 }
265 } else {
266 log::warn!("Got notification for unknown subscription (id: {:?})", id);
267 }
268 } else {
269 log::error!("Got unsupported notification (id: {:?})", id);
270 }
271 }
272
273 Ok(())
274}
275
276fn respond(
277 pending_response_txs: &mut BTreeMap<RequestId, oneshot::Sender<rpc::Value>>,
278 response: rpc::Response,
279) -> std::result::Result<(), ()> {
280 let outputs = match response {
281 rpc::Response::Single(output) => vec![output],
282 rpc::Response::Batch(outputs) => outputs,
283 };
284
285 for output in outputs {
286 let _ = respond_output(pending_response_txs, output);
287 }
288
289 Ok(())
290}
291
292fn respond_output(
293 pending_response_txs: &mut BTreeMap<RequestId, oneshot::Sender<rpc::Value>>,
294 output: rpc::Output,
295) -> std::result::Result<(), ()> {
296 let id = output.id().clone();
297
298 let value = helpers::to_result_from_output(output).map_err(|err| {
299 log::warn!("Unable to parse output into rpc::Value: {:?}", err);
300 })?;
301
302 let id = match id {
303 rpc::Id::Num(num) => num as usize,
304 _ => {
305 log::warn!("Got unsupported response (id: {:?})", id);
306 return Err(());
307 }
308 };
309
310 let response_tx = pending_response_txs.remove(&id).ok_or_else(|| {
311 log::warn!("Got response for unknown request (id: {:?})", id);
312 })?;
313
314 response_tx.send(value).map_err(|err| {
315 log::warn!("Sending a response to deallocated channel: {:?}", err);
316 })
317}
318
319impl From<mpsc::error::SendError<TransportMessage>> for Error {
320 fn from(err: mpsc::error::SendError<TransportMessage>) -> Self {
321 Error::Transport(format!("Send Error: {:?}", err))
322 }
323}
324
325impl From<oneshot::error::RecvError> for Error {
326 fn from(err: oneshot::error::RecvError) -> Self {
327 Error::Transport(format!("Recv Error: {:?}", err))
328 }
329}
330
/// Tests that drive `Ipc` against an in-process fake node on the other end of a socket pair.
#[cfg(all(test, unix))]
mod test {
    use super::*;
    use serde_json::json;
    use tokio::{
        io::{reader_stream, AsyncWriteExt},
        net::UnixStream,
    };

    /// Two sequential single calls. The second response arrives in 3-byte chunks.
    #[tokio::test]
    async fn works_for_single_requests() {
        let (stream1, stream2) = UnixStream::pair().unwrap();
        let ipc = Ipc::with_stream(stream1);

        tokio::spawn(eth_node_single(stream2));

        let (req_id, request) = ipc.prepare(
            "eth_test",
            vec![json!({
                "test": -1,
            })],
        );
        let response = ipc.send(req_id, request).await;
        let expected_response_json: serde_json::Value = json!({
            "test": 1,
        });
        assert_eq!(response, Ok(expected_response_json));

        let (req_id, request) = ipc.prepare(
            "eth_test",
            vec![json!({
                "test": 3,
            })],
        );
        let response = ipc.send(req_id, request).await;
        let expected_response_json: serde_json::Value = json!({
            "test": "string1",
        });
        assert_eq!(response, Ok(expected_response_json));
    }

    /// Fake node for `works_for_single_requests`.
    ///
    /// Checks each request, then answers it; the second answer is split into 3-byte
    /// chunks to exercise reassembly of partial reads.
    ///
    /// NOTE(review): assumes each `rx.next()` yields exactly one complete request.
    async fn eth_node_single(stream: UnixStream) {
        let (rx, mut tx) = stream.into_split();

        let mut rx = reader_stream(rx);
        if let Some(Ok(bytes)) = rx.next().await {
            let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

            assert_eq!(
                v,
                json!({
                    "jsonrpc": "2.0",
                    "method": "eth_test",
                    "id": 1,
                    "params": [{
                        "test": -1
                    }]
                })
            );

            // `write_all` so a short write cannot silently drop part of the response.
            tx.write_all(r#"{"jsonrpc": "2.0", "id": 1, "result": {"test": 1}}"#.as_ref())
                .await
                .unwrap();
            tx.flush().await.unwrap();
        }

        if let Some(Ok(bytes)) = rx.next().await {
            let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

            assert_eq!(
                v,
                json!({
                    "jsonrpc": "2.0",
                    "method": "eth_test",
                    "id": 2,
                    "params": [{
                        "test": 3
                    }]
                })
            );

            let response_bytes = r#"{"jsonrpc": "2.0", "id": 2, "result": {"test": "string1"}}"#;
            for chunk in response_bytes.as_bytes().chunks(3) {
                tx.write_all(chunk).await.unwrap();
                tx.flush().await.unwrap();
            }
        }
    }

    /// A two-element batch is written as one JSON array and answered in order.
    #[tokio::test]
    async fn works_for_batch_request() {
        let (stream1, stream2) = UnixStream::pair().unwrap();
        let ipc = Ipc::with_stream(stream1);

        tokio::spawn(eth_node_batch(stream2));

        let requests = vec![json!({"test": -1,}), json!({"test": 3,})];
        let requests = requests.into_iter().map(|v| ipc.prepare("eth_test", vec![v]));

        let response = ipc.send_batch(requests).await;
        let expected_response_json = vec![Ok(json!({"test": 1})), Ok(json!({"test": "string1"}))];

        assert_eq!(response, Ok(expected_response_json));
    }

    /// Fake node for `works_for_batch_request`: checks the batch, then answers it with one array.
    async fn eth_node_batch(stream: UnixStream) {
        let (rx, mut tx) = stream.into_split();

        let mut rx = reader_stream(rx);
        if let Some(Ok(bytes)) = rx.next().await {
            let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

            assert_eq!(
                v,
                json!([{
                    "jsonrpc": "2.0",
                    "method": "eth_test",
                    "id": 1,
                    "params": [{
                        "test": -1
                    }]
                }, {
                    "jsonrpc": "2.0",
                    "method": "eth_test",
                    "id": 2,
                    "params": [{
                        "test": 3
                    }]
                }])
            );

            let response = json!([
                {"jsonrpc": "2.0", "id": 1, "result": {"test": 0}},
                {"jsonrpc": "2.0", "id": 2, "result": {"test": "string1"}},
            ]);

            tx.write_all(serde_json::to_string(&response).unwrap().as_ref())
                .await
                .unwrap();

            tx.flush().await.unwrap();
        }
    }

    /// Three single calls answered by one batch in which request 2 carries a string id.
    ///
    /// Requests 1 and 3 succeed. Request 2 fails once the fake node closes the socket.
    #[tokio::test]
    async fn works_for_partial_batches() {
        let (stream1, stream2) = UnixStream::pair().unwrap();
        let ipc = Ipc::with_stream(stream1);

        tokio::spawn(eth_node_partial_batches(stream2));

        let requests = vec![json!({"test": 0}), json!({"test": 1}), json!({"test": 2})];
        let requests = requests.into_iter().map(|v| ipc.execute("eth_test", vec![v]));
        let responses = join_all(requests).await;

        assert_eq!(responses[0], Ok(json!({"test": 0})));
        assert_eq!(responses[2], Ok(json!({"test": 2})));
        assert!(responses[1].is_err());
    }

    /// Fake node for `works_for_partial_batches`.
    ///
    /// Buffers input until three complete requests have arrived, then answers with one
    /// batch and drops the stream.
    async fn eth_node_partial_batches(stream: UnixStream) {
        let (rx, mut tx) = stream.into_split();
        let mut buf = vec![];
        let mut rx = reader_stream(rx);
        while let Some(Ok(bytes)) = rx.next().await {
            buf.extend(bytes);

            let requests: std::result::Result<Vec<serde_json::Value>, serde_json::Error> =
                serde_json::Deserializer::from_slice(&buf).into_iter().collect();

            if let Ok(requests) = requests {
                if requests.len() == 3 {
                    break;
                }
            }
        }

        // The string id "2" is deliberately unsupported by the transport.
        let response = json!([
            {"jsonrpc": "2.0", "id": 1, "result": {"test": 0}},
            {"jsonrpc": "2.0", "id": "2", "result": {"test": 2}},
            {"jsonrpc": "2.0", "id": 3, "result": {"test": 2}},
        ]);

        tx.write_all(serde_json::to_string(&response).unwrap().as_ref())
            .await
            .unwrap();

        tx.flush().await.unwrap();
    }
}
520}