cs_mwc_web3/transports/
ipc.rs

1//! IPC transport
2
3use crate::{api::SubscriptionId, helpers, BatchTransport, DuplexTransport, Error, RequestId, Result, Transport};
4use futures::future::{join_all, JoinAll};
5use jsonrpc_core as rpc;
6use std::{
7    collections::BTreeMap,
8    path::Path,
9    pin::Pin,
10    sync::{atomic::AtomicUsize, Arc},
11    task::{Context, Poll},
12};
13use tokio::{
14    io::{reader_stream, AsyncWriteExt},
15    net::UnixStream,
16    stream::StreamExt,
17    sync::{mpsc, oneshot},
18};
19
20/// Unix Domain Sockets (IPC) transport.
21#[derive(Debug, Clone)]
22pub struct Ipc {
23    id: Arc<AtomicUsize>,
24    messages_tx: mpsc::UnboundedSender<TransportMessage>,
25}
26
27#[cfg(unix)]
28impl Ipc {
29    /// Creates a new IPC transport from a given path.
30    ///
31    /// IPC is only available on Unix.
32    pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
33        let stream = UnixStream::connect(path).await?;
34
35        Ok(Self::with_stream(stream))
36    }
37
38    fn with_stream(stream: UnixStream) -> Self {
39        let id = Arc::new(AtomicUsize::new(1));
40        let (messages_tx, messages_rx) = mpsc::unbounded_channel();
41
42        tokio::spawn(run_server(stream, messages_rx));
43
44        Ipc { id, messages_tx }
45    }
46}
47
48impl Transport for Ipc {
49    type Out = SingleResponse;
50
51    fn prepare(&self, method: &str, params: Vec<rpc::Value>) -> (crate::RequestId, rpc::Call) {
52        let id = self.id.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
53        let request = helpers::build_request(id, method, params);
54        (id, request)
55    }
56
57    fn send(&self, id: RequestId, call: rpc::Call) -> Self::Out {
58        let (response_tx, response_rx) = oneshot::channel();
59        let message = TransportMessage::Single((id, call, response_tx));
60
61        SingleResponse(self.messages_tx.send(message).map(|()| response_rx).map_err(Into::into))
62    }
63}
64
65impl BatchTransport for Ipc {
66    type Batch = BatchResponse;
67
68    fn send_batch<T: IntoIterator<Item = (RequestId, rpc::Call)>>(&self, requests: T) -> Self::Batch {
69        let mut response_rxs = vec![];
70
71        let message = TransportMessage::Batch(
72            requests
73                .into_iter()
74                .map(|(id, call)| {
75                    let (response_tx, response_rx) = oneshot::channel();
76                    response_rxs.push(response_rx);
77
78                    (id, call, response_tx)
79                })
80                .collect(),
81        );
82
83        BatchResponse(
84            self.messages_tx
85                .send(message)
86                .map(|()| join_all(response_rxs))
87                .map_err(Into::into),
88        )
89    }
90}
91
92impl DuplexTransport for Ipc {
93    type NotificationStream = mpsc::UnboundedReceiver<rpc::Value>;
94
95    fn subscribe(&self, id: SubscriptionId) -> Result<Self::NotificationStream> {
96        let (tx, rx) = mpsc::unbounded_channel();
97        self.messages_tx.send(TransportMessage::Subscribe(id, tx))?;
98        Ok(rx)
99    }
100
101    fn unsubscribe(&self, id: SubscriptionId) -> Result<()> {
102        self.messages_tx
103            .send(TransportMessage::Unsubscribe(id))
104            .map_err(Into::into)
105    }
106}
107
108/// A future representing a pending RPC request. Resolves to a JSON RPC value.
109pub struct SingleResponse(Result<oneshot::Receiver<rpc::Value>>);
110
111impl futures::Future for SingleResponse {
112    type Output = Result<rpc::Value>;
113    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
114        match &mut self.0 {
115            Err(err) => Poll::Ready(Err(err.clone())),
116            Ok(ref mut rx) => {
117                let value = ready!(futures::Future::poll(Pin::new(rx), cx))?;
118                Poll::Ready(Ok(value))
119            }
120        }
121    }
122}
123
124/// A future representing a pending batch RPC request. Resolves to a vector of JSON RPC value.
125pub struct BatchResponse(Result<JoinAll<oneshot::Receiver<rpc::Value>>>);
126
127impl futures::Future for BatchResponse {
128    type Output = Result<Vec<Result<rpc::Value>>>;
129    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
130        match &mut self.0 {
131            Err(err) => Poll::Ready(Err(err.clone())),
132            Ok(ref mut rxs) => {
133                let poll = futures::Future::poll(Pin::new(rxs), cx);
134                let values = ready!(poll).into_iter().map(|r| r.map_err(Into::into)).collect();
135
136                Poll::Ready(Ok(values))
137            }
138        }
139    }
140}
141
142type TransportRequest = (RequestId, rpc::Call, oneshot::Sender<rpc::Value>);
143
144#[derive(Debug)]
145enum TransportMessage {
146    Single(TransportRequest),
147    Batch(Vec<TransportRequest>),
148    Subscribe(SubscriptionId, mpsc::UnboundedSender<rpc::Value>),
149    Unsubscribe(SubscriptionId),
150}
151
152#[cfg(unix)]
153async fn run_server(mut unix_stream: UnixStream, messages_rx: mpsc::UnboundedReceiver<TransportMessage>) -> Result<()> {
154    let (socket_reader, mut socket_writer) = unix_stream.split();
155    let mut pending_response_txs = BTreeMap::default();
156    let mut subscription_txs = BTreeMap::default();
157
158    let mut socket_reader = reader_stream(socket_reader);
159    let mut messages_rx = messages_rx.fuse();
160    let mut read_buffer = vec![];
161    let mut closed = false;
162
163    while !closed || pending_response_txs.len() > 0 {
164        tokio::select! {
165            message = messages_rx.next() => match message {
166                None => closed = true,
167                Some(TransportMessage::Subscribe(id, tx)) => {
168                    if let Some(_) = subscription_txs.insert(id.clone(), tx) {
169                        log::warn!("Replacing a subscription with id {:?}", id);
170                    }
171                },
172                Some(TransportMessage::Unsubscribe(id)) => {
173                    if let None = subscription_txs.remove(&id) {
174                        log::warn!("Unsubscribing not subscribed id {:?}", id);
175                    }
176                },
177                Some(TransportMessage::Single((request_id, rpc_call, response_tx))) => {
178                    if pending_response_txs.insert(request_id, response_tx).is_some() {
179                        log::warn!("Replacing a pending request with id {:?}", request_id);
180                    }
181
182                    let bytes = helpers::to_string(&rpc::Request::Single(rpc_call)).into_bytes();
183                    if let Err(err) = socket_writer.write(&bytes).await {
184                        pending_response_txs.remove(&request_id);
185                        log::error!("IPC write error: {:?}", err);
186                    }
187                }
188                Some(TransportMessage::Batch(requests)) => {
189                    let mut request_ids = vec![];
190                    let mut rpc_calls = vec![];
191
192                    for (request_id, rpc_call, response_tx) in requests {
193                        request_ids.push(request_id);
194                        rpc_calls.push(rpc_call);
195
196                        if pending_response_txs.insert(request_id, response_tx).is_some() {
197                            log::warn!("Replacing a pending request with id {:?}", request_id);
198                        }
199                    }
200
201                    let bytes = helpers::to_string(&rpc::Request::Batch(rpc_calls)).into_bytes();
202
203                    if let Err(err) = socket_writer.write(&bytes).await {
204                        log::error!("IPC write error: {:?}", err);
205                        for request_id in request_ids {
206                            pending_response_txs.remove(&request_id);
207                        }
208                    }
209                }
210            },
211            bytes = socket_reader.next() => match bytes {
212                Some(Ok(bytes)) => {
213                    read_buffer.extend_from_slice(&bytes);
214
215                    let read_len = {
216                        let mut de: serde_json::StreamDeserializer<_, serde_json::Value> =
217                            serde_json::Deserializer::from_slice(&read_buffer).into_iter();
218
219                        while let Some(Ok(value)) = de.next() {
220                            if let Ok(notification) = serde_json::from_value::<rpc::Notification>(value.clone()) {
221                                let _ = notify(&mut subscription_txs, notification);
222                                continue;
223                            }
224
225                            if let Ok(response) = serde_json::from_value::<rpc::Response>(value) {
226                                let _ = respond(&mut pending_response_txs, response);
227                                continue;
228                            }
229
230                            log::warn!("JSON is not a response or notification");
231                        }
232
233                        de.byte_offset()
234                    };
235
236                    read_buffer.copy_within(read_len.., 0);
237                    read_buffer.truncate(read_buffer.len() - read_len);
238                },
239                Some(Err(err)) => {
240                    log::error!("IPC read error: {:?}", err);
241                    return Err(err.into());
242                },
243                None => break,
244            }
245        };
246    }
247
248    Ok(())
249}
250
251fn notify(
252    subscription_txs: &mut BTreeMap<SubscriptionId, mpsc::UnboundedSender<rpc::Value>>,
253    notification: rpc::Notification,
254) -> std::result::Result<(), ()> {
255    if let rpc::Params::Map(params) = notification.params {
256        let id = params.get("subscription");
257        let result = params.get("result");
258
259        if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) {
260            let id: SubscriptionId = id.clone().into();
261            if let Some(tx) = subscription_txs.get(&id) {
262                if let Err(e) = tx.send(result.clone()) {
263                    log::error!("Error sending notification: {:?} (id: {:?}", e, id);
264                }
265            } else {
266                log::warn!("Got notification for unknown subscription (id: {:?})", id);
267            }
268        } else {
269            log::error!("Got unsupported notification (id: {:?})", id);
270        }
271    }
272
273    Ok(())
274}
275
276fn respond(
277    pending_response_txs: &mut BTreeMap<RequestId, oneshot::Sender<rpc::Value>>,
278    response: rpc::Response,
279) -> std::result::Result<(), ()> {
280    let outputs = match response {
281        rpc::Response::Single(output) => vec![output],
282        rpc::Response::Batch(outputs) => outputs,
283    };
284
285    for output in outputs {
286        let _ = respond_output(pending_response_txs, output);
287    }
288
289    Ok(())
290}
291
292fn respond_output(
293    pending_response_txs: &mut BTreeMap<RequestId, oneshot::Sender<rpc::Value>>,
294    output: rpc::Output,
295) -> std::result::Result<(), ()> {
296    let id = output.id().clone();
297
298    let value = helpers::to_result_from_output(output).map_err(|err| {
299        log::warn!("Unable to parse output into rpc::Value: {:?}", err);
300    })?;
301
302    let id = match id {
303        rpc::Id::Num(num) => num as usize,
304        _ => {
305            log::warn!("Got unsupported response (id: {:?})", id);
306            return Err(());
307        }
308    };
309
310    let response_tx = pending_response_txs.remove(&id).ok_or_else(|| {
311        log::warn!("Got response for unknown request (id: {:?})", id);
312    })?;
313
314    response_tx.send(value).map_err(|err| {
315        log::warn!("Sending a response to deallocated channel: {:?}", err);
316    })
317}
318
319impl From<mpsc::error::SendError<TransportMessage>> for Error {
320    fn from(err: mpsc::error::SendError<TransportMessage>) -> Self {
321        Error::Transport(format!("Send Error: {:?}", err))
322    }
323}
324
325impl From<oneshot::error::RecvError> for Error {
326    fn from(err: oneshot::error::RecvError) -> Self {
327        Error::Transport(format!("Recv Error: {:?}", err))
328    }
329}
330
/// Tests that run the transport against in-process fake nodes connected through
/// `UnixStream::pair`.
#[cfg(all(test, unix))]
mod test {
    use super::*;
    use serde_json::json;
    use tokio::{
        io::{reader_stream, AsyncWriteExt},
        net::UnixStream,
    };

    /// Two sequential single requests each receive their own response, including
    /// one that arrives in small chunks.
    #[tokio::test]
    async fn works_for_single_requests() {
        let (stream1, stream2) = UnixStream::pair().unwrap();
        let ipc = Ipc::with_stream(stream1);

        tokio::spawn(eth_node_single(stream2));

        let (req_id, request) = ipc.prepare(
            "eth_test",
            vec![json!({
                "test": -1,
            })],
        );
        let response = ipc.send(req_id, request).await;
        let expected_response_json: serde_json::Value = json!({
            "test": 1,
        });
        assert_eq!(response, Ok(expected_response_json));

        let (req_id, request) = ipc.prepare(
            "eth_test",
            vec![json!({
                "test": 3,
            })],
        );
        let response = ipc.send(req_id, request).await;
        let expected_response_json: serde_json::Value = json!({
            "test": "string1",
        });
        assert_eq!(response, Ok(expected_response_json));
    }

    /// Fake node for `works_for_single_requests`: checks the two expected requests
    /// and answers the second one three bytes at a time.
    async fn eth_node_single(stream: UnixStream) {
        let (rx, mut tx) = stream.into_split();

        let mut rx = reader_stream(rx);
        if let Some(Ok(bytes)) = rx.next().await {
            let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

            assert_eq!(
                v,
                json!({
                    "jsonrpc": "2.0",
                    "method": "eth_test",
                    "id": 1,
                    "params": [{
                        "test": -1
                    }]
                })
            );

            // `write_all` so the whole response is sent even if one write is short.
            tx.write_all(r#"{"jsonrpc": "2.0", "id": 1, "result": {"test": 1}}"#.as_bytes())
                .await
                .unwrap();
            tx.flush().await.unwrap();
        }

        if let Some(Ok(bytes)) = rx.next().await {
            let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

            assert_eq!(
                v,
                json!({
                    "jsonrpc": "2.0",
                    "method": "eth_test",
                    "id": 2,
                    "params": [{
                        "test": 3
                    }]
                })
            );

            // Send the response in tiny pieces to exercise reassembly on the client side.
            let response_bytes = r#"{"jsonrpc": "2.0", "id": 2, "result": {"test": "string1"}}"#;
            for chunk in response_bytes.as_bytes().chunks(3) {
                tx.write_all(chunk).await.unwrap();
                tx.flush().await.unwrap();
            }
        }
    }

    /// A batch of two requests resolves to both results in request order.
    #[tokio::test]
    async fn works_for_batch_request() {
        let (stream1, stream2) = UnixStream::pair().unwrap();
        let ipc = Ipc::with_stream(stream1);

        tokio::spawn(eth_node_batch(stream2));

        let requests = vec![json!({"test": -1,}), json!({"test": 3,})];
        let requests = requests.into_iter().map(|v| ipc.prepare("eth_test", vec![v]));

        let response = ipc.send_batch(requests).await;
        let expected_response_json = vec![Ok(json!({"test": 1})), Ok(json!({"test": "string1"}))];

        assert_eq!(response, Ok(expected_response_json));
    }

    /// Fake node for `works_for_batch_request`: checks the batch and answers it in one write.
    async fn eth_node_batch(stream: UnixStream) {
        let (rx, mut tx) = stream.into_split();

        let mut rx = reader_stream(rx);
        if let Some(Ok(bytes)) = rx.next().await {
            let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

            assert_eq!(
                v,
                json!([{
                    "jsonrpc": "2.0",
                    "method": "eth_test",
                    "id": 1,
                    "params": [{
                        "test": -1
                    }]
                }, {
                    "jsonrpc": "2.0",
                    "method": "eth_test",
                    "id": 2,
                    "params": [{
                        "test": 3
                    }]
                }])
            );

            let response = json!([
                {"jsonrpc": "2.0", "id": 1, "result": {"test": 1}},
                {"jsonrpc": "2.0", "id": 2, "result": {"test": "string1"}},
            ]);

            tx.write_all(serde_json::to_string(&response).unwrap().as_ref())
                .await
                .unwrap();

            tx.flush().await.unwrap();
        }
    }

    /// When a batched reply carries an unusable id for one request, the other
    /// requests still succeed and the affected one resolves to an error.
    #[tokio::test]
    async fn works_for_partial_batches() {
        let (stream1, stream2) = UnixStream::pair().unwrap();
        let ipc = Ipc::with_stream(stream1);

        tokio::spawn(eth_node_partial_batches(stream2));

        let requests = vec![json!({"test": 0}), json!({"test": 1}), json!({"test": 2})];
        let requests = requests.into_iter().map(|v| ipc.execute("eth_test", vec![v]));
        let responses = join_all(requests).await;

        assert_eq!(responses[0], Ok(json!({"test": 0})));
        assert_eq!(responses[2], Ok(json!({"test": 2})));
        assert!(responses[1].is_err());
    }

    /// Fake node for `works_for_partial_batches`: waits for three requests, then
    /// replies with a batch in which the second id is a string instead of a number.
    async fn eth_node_partial_batches(stream: UnixStream) {
        let (rx, mut tx) = stream.into_split();
        let mut buf = vec![];
        let mut rx = reader_stream(rx);
        while let Some(Ok(bytes)) = rx.next().await {
            buf.extend(bytes);

            // Re-parse everything received so far until all three requests are complete.
            let requests: std::result::Result<Vec<serde_json::Value>, serde_json::Error> =
                serde_json::Deserializer::from_slice(&buf).into_iter().collect();

            if let Ok(requests) = requests {
                if requests.len() == 3 {
                    break;
                }
            }
        }

        let response = json!([
            {"jsonrpc": "2.0", "id": 1, "result": {"test": 0}},
            {"jsonrpc": "2.0", "id": "2", "result": {"test": 2}},
            {"jsonrpc": "2.0", "id": 3, "result": {"test": 2}},
        ]);

        tx.write_all(serde_json::to_string(&response).unwrap().as_ref())
            .await
            .unwrap();

        tx.flush().await.unwrap();
    }
}