1#![cfg(all(unix, not(target_os = "emscripten"), feature = "tokio"))]
35#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
36
37use std::{
38 collections::VecDeque,
39 io,
40 path::PathBuf,
41 pin::Pin,
42 task::{Context, Poll},
43};
44
45use futures::{
46 future::{BoxFuture, Ready},
47 prelude::*,
48 stream::BoxStream,
49};
50use libp2p_core::{
51 multiaddr::{Multiaddr, Protocol},
52 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
53 Transport,
54};
55
56pub type Listener<T> = BoxStream<
57 'static,
58 Result<
59 TransportEvent<<T as Transport>::ListenerUpgrade, <T as Transport>::Error>,
60 Result<(), <T as Transport>::Error>,
61 >,
62>;
63
macro_rules! codegen {
    ($feature_name:expr, $uds_config:ident, $build_listener:expr, $unix_stream:ty, $($mut_or_not:tt)*) => {
        /// Unix domain sockets transport for libp2p.
        ///
        /// Only multiaddresses that start with a `/unix` component holding an absolute path are
        /// supported; every other address is rejected with
        /// [`TransportError::MultiaddrNotSupported`].
        pub struct $uds_config {
            /// Active listeners. `Transport::poll` pops them from the back and re-inserts them at
            /// the front, so they are polled in round-robin order.
            listeners: VecDeque<(ListenerId, Listener<Self>)>,
        }

        impl $uds_config {
            /// Creates a new transport without any active listener.
            pub fn new() -> $uds_config {
                $uds_config {
                    listeners: VecDeque::new(),
                }
            }
        }

        impl Default for $uds_config {
            fn default() -> Self {
                Self::new()
            }
        }

        impl Transport for $uds_config {
            type Output = $unix_stream;
            type Error = io::Error;
            // Accepted sockets need no further work at this layer, so the upgrade is an
            // already-resolved future.
            type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
            type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;

            /// Registers a listener for the Unix socket path contained in `addr`.
            ///
            /// The listener future produced by `$build_listener` is driven from `poll`. If it
            /// fails, `poll` reports `TransportEvent::ListenerClosed` with an `Err` reason.
            /// Otherwise the listener emits one `TransportEvent::NewAddress`, then a
            /// `TransportEvent::Incoming` per accepted connection (or a
            /// `TransportEvent::ListenerError` per failed accept).
            fn listen_on(
                &mut self,
                id: ListenerId,
                addr: Multiaddr,
            ) -> Result<(), TransportError<Self::Error>> {
                if let Ok(path) = multiaddr_to_path(&addr) {
                    // `$build_listener` is a closure literal at the invocation site, so this is
                    // intentionally an immediately-invoked closure.
                    #[allow(clippy::redundant_closure_call)]
                    let listener = $build_listener(path)
                        // Outer `Err` closes the listener; the inner `Err` is the close reason.
                        .map_err(Err)
                        .map_ok(move |listener| {
                            stream::once({
                                let addr = addr.clone();
                                async move {
                                    tracing::debug!(address=%addr, "Now listening on address");
                                    Ok(TransportEvent::NewAddress {
                                        listener_id: id,
                                        listen_addr: addr,
                                    })
                                }
                            })
                            .chain(stream::unfold(
                                listener,
                                move |listener| {
                                    let addr = addr.clone();
                                    async move {
                                        let event = match listener.accept().await {
                                            Ok((stream, _)) => {
                                                tracing::debug!(address=%addr, "incoming connection on address");
                                                TransportEvent::Incoming {
                                                    upgrade: future::ok(stream),
                                                    local_addr: addr.clone(),
                                                    send_back_addr: addr,
                                                    listener_id: id,
                                                }
                                            }
                                            // Accept errors are reported but keep the listener alive.
                                            Err(error) => TransportEvent::ListenerError {
                                                listener_id: id,
                                                error,
                                            },
                                        };
                                        // Always `Some`: the accept loop never ends on its own.
                                        Some((Ok(event), listener))
                                    }
                                },
                            ))
                        })
                        .try_flatten_stream()
                        .boxed();
                    self.listeners.push_back((id, listener));
                    Ok(())
                } else {
                    Err(TransportError::MultiaddrNotSupported(addr))
                }
            }

            /// Schedules the listener `id` for removal.
            ///
            /// Returns `false` if no listener with that id exists. The closure itself is reported
            /// by the next `poll` as `TransportEvent::ListenerClosed` with reason `Ok(())`.
            fn remove_listener(&mut self, id: ListenerId) -> bool {
                match self
                    .listeners
                    .iter_mut()
                    .find(|(listener_id, _)| *listener_id == id)
                {
                    Some((_, listener)) => {
                        // Swap the live stream for one that immediately reports a clean close;
                        // `poll` then drops it instead of re-queueing it.
                        *listener = stream::once(async { Err(Ok(())) }).boxed();
                        true
                    }
                    None => false,
                }
            }

            /// Returns a future connecting to the Unix socket path contained in `addr`.
            fn dial(&mut self, addr: Multiaddr, _dial_opts: DialOpts) -> Result<Self::Dial, TransportError<Self::Error>> {
                if let Ok(path) = multiaddr_to_path(&addr) {
                    tracing::debug!(address=%addr, "Dialing address");
                    Ok(async move { <$unix_stream>::connect(&path).await }.boxed())
                } else {
                    Err(TransportError::MultiaddrNotSupported(addr))
                }
            }

            /// Polls every listener at most once and returns the first event found.
            fn poll(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
                // Pop from the back, push to the front: this rotates the queue so that no
                // listener can starve the others.
                let mut remaining = self.listeners.len();
                while let Some((id, mut listener)) = self.listeners.pop_back() {
                    let event = match Stream::poll_next(Pin::new(&mut listener), cx) {
                        Poll::Pending => None,
                        // A listener stream only terminates after yielding an `Err`, and it is
                        // dropped as soon as that happens (arm below).
                        Poll::Ready(None) => {
                            unreachable!("listener stream ended without reporting its closure")
                        }
                        Poll::Ready(Some(Ok(event))) => Some(event),
                        // The listener is finished: report it and do not re-queue it.
                        Poll::Ready(Some(Err(reason))) => {
                            return Poll::Ready(TransportEvent::ListenerClosed {
                                listener_id: id,
                                reason,
                            });
                        }
                    };
                    self.listeners.push_front((id, listener));
                    if let Some(event) = event {
                        return Poll::Ready(event);
                    }
                    remaining -= 1;
                    if remaining == 0 {
                        break;
                    }
                }
                Poll::Pending
            }
        }
    };
}
203
// Tokio flavour of the transport: listeners are `tokio::net::UnixListener`s bound inside the
// returned future, and both accepted and dialed connections are `tokio::net::UnixStream`s.
#[cfg(feature = "tokio")]
codegen! {
    "tokio",
    TokioUdsConfig,
    |path: PathBuf| async move { tokio::net::UnixListener::bind(&path) },
    tokio::net::UnixStream,
}
211
212fn multiaddr_to_path(addr: &Multiaddr) -> Result<PathBuf, ()> {
218 let mut protocols = addr.iter();
219 match protocols.next() {
220 Some(Protocol::Unix(ref path)) => {
221 let path = PathBuf::from(path.as_ref());
222 if !path.is_absolute() {
223 return Err(());
224 }
225 match protocols.next() {
226 None | Some(Protocol::P2p(_)) => Ok(path),
227 Some(_) => Err(()),
228 }
229 }
230 _ => Err(()),
231 }
232}
233
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use std::{borrow::Cow, path::Path};

    use futures::{channel::oneshot, prelude::*};
    use libp2p_core::{
        multiaddr::{Multiaddr, Protocol},
        transport::{DialOpts, ListenerId, PortUse},
        Endpoint, Transport,
    };
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::{multiaddr_to_path, TokioUdsConfig};

    #[test]
    fn multiaddr_to_path_conversion() {
        assert!(
            multiaddr_to_path(&"/ip4/127.0.0.1/udp/1234".parse::<Multiaddr>().unwrap()).is_err()
        );

        assert_eq!(
            multiaddr_to_path(&Multiaddr::from(Protocol::Unix("/tmp/foo".into()))),
            Ok(Path::new("/tmp/foo").to_owned())
        );
        assert_eq!(
            multiaddr_to_path(&Multiaddr::from(Protocol::Unix("/home/bar/baz".into()))),
            Ok(Path::new("/home/bar/baz").to_owned())
        );

        // Relative paths are never accepted.
        assert!(multiaddr_to_path(&Multiaddr::from(Protocol::Unix("foo/bar".into()))).is_err());

        // A non-`/p2p` component after the path is rejected.
        let with_tcp =
            Multiaddr::from(Protocol::Unix("/tmp/foo".into())).with(Protocol::Tcp(1234));
        assert!(multiaddr_to_path(&with_tcp).is_err());
    }

    #[tokio::test]
    async fn communicating_between_dialer_and_listener() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket = temp_dir.path().join("socket");
        let addr = Multiaddr::from(Protocol::Unix(Cow::Owned(
            socket.to_string_lossy().into_owned(),
        )));

        let (tx, rx) = oneshot::channel();

        let listener = async move {
            let mut transport = TokioUdsConfig::new().boxed();
            transport.listen_on(ListenerId::next(), addr).unwrap();

            let listen_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("listen address");

            tx.send(listen_addr).unwrap();

            let (sock, _addr) = transport
                .select_next_some()
                .await
                .into_incoming()
                .expect("incoming stream");

            let mut sock = sock.await.unwrap();
            let mut buf = [0u8; 3];
            sock.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, [1, 2, 3]);
        };

        let dialer = async move {
            let mut uds = TokioUdsConfig::new();
            let addr = rx.await.unwrap();
            let mut socket = uds
                .dial(
                    addr,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::Reuse,
                    },
                )
                .unwrap()
                .await
                .unwrap();
            // `write` may perform a partial write; `write_all` guarantees the listener's
            // `read_exact` receives all three bytes.
            socket.write_all(&[1, 2, 3]).await.unwrap();
        };

        tokio::join!(listener, dialer);
    }

    #[test]
    // NOTE(review): ignored without a recorded reason here; presumably parsing this address
    // into a `Multiaddr` fails. Confirm before re-enabling.
    #[ignore]
    fn larger_addr_denied() {
        let mut uds = TokioUdsConfig::new();

        let addr = "/unix//foo/bar".parse::<Multiaddr>().unwrap();
        assert!(uds.listen_on(ListenerId::next(), addr).is_err());
    }

    #[test]
    // NOTE(review): ignored without a recorded reason here; confirm whether it can be re-enabled.
    #[ignore]
    fn relative_addr_denied() {
        assert!("/unix/./foo/bar".parse::<Multiaddr>().is_err());
    }
}