Skip to main content

libp2p_wasi_sockets/
transport.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use libp2p_core::multiaddr::Multiaddr;
8use libp2p_core::transport::{DialOpts, ListenerId, TransportError, TransportEvent};
9use libp2p_core::Transport;
10use tracing::warn;
11
12use crate::error::Error;
13use crate::multiaddr::{multiaddr_to_socketaddr, socketaddr_to_multiaddr};
14use crate::stream::WasiTcpStream;
15
/// Configuration for [`WasiTcpTransport`].
///
/// Start from [`Config::default`] and override individual fields with struct
/// update syntax, e.g. `Config { nodelay: false, ..Config::default() }`.
///
/// **Current limitation:** the transport stores this value but does not read any
/// of its fields yet. No socket options are set, and the listen backlog is not
/// forwarded to the host. The transport source marks nodelay/keep-alive support
/// as planned for M1. The field docs below describe the *intended* effect of
/// each setting.
#[derive(Debug, Clone)]
pub struct Config {
    /// Disable Nagle's algorithm (`TCP_NODELAY`). Defaults to `true`.
    pub nodelay: bool,
    /// TCP keep-alive interval. `None` disables keep-alive. Defaults to `None`.
    pub keep_alive: Option<Duration>,
    /// Listen backlog intended for `wasi:sockets/tcp.set-listen-backlog-size`.
    /// Defaults to 128.
    pub listen_backlog: u32,
}

impl Default for Config {
    /// Returns `nodelay: true`, `keep_alive: None`, `listen_backlog: 128`.
    fn default() -> Self {
        Self {
            nodelay: true,
            keep_alive: None,
            listen_backlog: 128,
        }
    }
}
36
37// ── wasm32-wasip2 implementation ─────────────────────────────────────────────
38
/// Pinned, heap-allocated future without a `Send` bound.
///
/// Holds the in-flight bind, accept and dial operations.  Omitting `Send` is
/// intended for a single-threaded wasm32 executor.
#[cfg(target_arch = "wasm32")]
type WasmBoxFut<T> = Pin<Box<dyn core::future::Future<Output = T> + 'static>>;
42
/// Bookkeeping for one listener, keyed in the transport by its [`ListenerId`].
///
/// Lifecycle driven by `Transport::poll`: bind → (announce address) → accept
/// loop, or → close when `closing` is set.
///
/// NOTE(review): fields drop in declaration order and the accept future also
/// holds an `Arc` to the listener; keep that in mind before reordering fields.
#[cfg(target_arch = "wasm32")]
struct ListenerState {
    /// Socket address requested in `listen_on`; used as the reported address
    /// whenever the bound socket's local address is unavailable.
    bind_addr: std::net::SocketAddr,
    /// Bound listener socket; `None` until the bind future completes successfully.
    listener: Option<Arc<wstd::net::TcpListener>>,
    /// Pending bind operation; cleared once it resolves or the listener closes.
    bind_future: Option<WasmBoxFut<std::io::Result<wstd::net::TcpListener>>>,
    /// Pending accept operation; recreated after each accepted connection.
    accept_future: Option<WasmBoxFut<std::io::Result<wstd::net::TcpStream>>>,
    /// `true` once a `NewAddress` event has been emitted for this listener.
    /// During shutdown it doubles as the "`AddressExpired` still owed" flag: it is
    /// reset to `false` when that event is emitted, so the following `poll` emits
    /// `ListenerClosed`.
    announced: bool,
    /// Set by `remove_listener`.  Makes `poll` emit `AddressExpired` (only when
    /// `announced`) and then `ListenerClosed`, after which the entry is removed.
    closing: bool,
}
61
62/// A libp2p transport backed by `wasi:sockets/tcp`.
63///
64/// # Host requirements
65///
66/// The WASI host must grant network access to the component.  Under Wasmtime,
67/// pass `-S inherit-network` (or `--wasi inherit-network`).  Without it, all
68/// dials fail with [`Error::AccessDenied`] and listeners cannot be bound.
69pub struct WasiTcpTransport {
70    #[allow(dead_code)] // applied in M1 when nodelay/keep_alive socket options are set
71    config: Config,
72    #[cfg(target_arch = "wasm32")]
73    listeners: HashMap<ListenerId, ListenerState>,
74    #[cfg(not(target_arch = "wasm32"))]
75    _phantom: std::marker::PhantomData<()>,
76}
77
78// SAFETY: wasm32-wasip2 is single-threaded; WASI resource handles are safe to
79// "send" across the non-existent thread boundary.
80#[cfg(target_arch = "wasm32")]
81unsafe impl Send for WasiTcpTransport {}
82#[cfg(target_arch = "wasm32")]
83unsafe impl Sync for WasiTcpTransport {}
84
85impl WasiTcpTransport {
86    /// Create a transport with the given [`Config`].
87    pub fn new(config: Config) -> Self {
88        Self {
89            config,
90            #[cfg(target_arch = "wasm32")]
91            listeners: HashMap::new(),
92            #[cfg(not(target_arch = "wasm32"))]
93            _phantom: std::marker::PhantomData,
94        }
95    }
96}
97
98impl Default for WasiTcpTransport {
99    fn default() -> Self {
100        Self::new(Config::default())
101    }
102}
103
104impl Transport for WasiTcpTransport {
105    type Output = WasiTcpStream;
106    type Error = Error;
107    /// The upgrade is immediate: the accepted stream is already a connected
108    /// byte-stream; no further handshake is required at the transport layer.
109    type ListenerUpgrade = futures::future::Ready<Result<Self::Output, Self::Error>>;
110    /// A boxed, non-Send future — adequate for a single-threaded wasm32 executor.
111    #[cfg(target_arch = "wasm32")]
112    type Dial = WasmBoxFut<Result<Self::Output, Self::Error>>;
113    #[cfg(not(target_arch = "wasm32"))]
114    type Dial = futures::future::Pending<Result<Self::Output, Self::Error>>;
115
116    fn listen_on(
117        &mut self,
118        id: ListenerId,
119        addr: Multiaddr,
120    ) -> Result<(), TransportError<Self::Error>> {
121        let sock_addr = multiaddr_to_socketaddr(&addr).map_err(TransportError::Other)?;
122
123        #[cfg(target_arch = "wasm32")]
124        {
125            let addr_str = sock_addr.to_string();
126            let bind_fut: WasmBoxFut<std::io::Result<wstd::net::TcpListener>> =
127                Box::pin(async move { wstd::net::TcpListener::bind(&addr_str).await });
128
129            self.listeners.insert(
130                id,
131                ListenerState {
132                    bind_addr: sock_addr,
133                    listener: None,
134                    bind_future: Some(bind_fut),
135                    accept_future: None,
136                    announced: false,
137                    closing: false,
138                },
139            );
140        }
141
142        #[cfg(not(target_arch = "wasm32"))]
143        {
144            let _ = (id, sock_addr);
145        }
146
147        Ok(())
148    }
149
150    fn remove_listener(&mut self, id: ListenerId) -> bool {
151        #[cfg(target_arch = "wasm32")]
152        {
153            if let Some(state) = self.listeners.get_mut(&id) {
154                state.closing = true;
155                true
156            } else {
157                false
158            }
159        }
160        #[cfg(not(target_arch = "wasm32"))]
161        {
162            let _ = id;
163            false
164        }
165    }
166
167    fn dial(
168        &mut self,
169        addr: Multiaddr,
170        _opts: DialOpts,
171    ) -> Result<Self::Dial, TransportError<Self::Error>> {
172        let sock_addr = multiaddr_to_socketaddr(&addr).map_err(TransportError::Other)?;
173        let _ = &sock_addr; // used below only on wasm32
174
175        #[cfg(target_arch = "wasm32")]
176        {
177            let dial_fut: WasmBoxFut<Result<WasiTcpStream, Error>> =
178                Box::pin(async move {
179                    wstd::net::TcpStream::connect(sock_addr)
180                        .await
181                        .map(WasiTcpStream::new)
182                        .map_err(|e| {
183                            if e.kind() == std::io::ErrorKind::PermissionDenied {
184                                warn!(
185                                    "Network capability denied — pass `-S inherit-network` \
186                                     to wasmtime to grant the component network access."
187                                );
188                                Error::AccessDenied
189                            } else {
190                                Error::Io(e)
191                            }
192                        })
193                });
194            return Ok(dial_fut);
195        }
196
197        #[cfg(not(target_arch = "wasm32"))]
198        Err(TransportError::Other(Error::UnsupportedMultiaddr(addr)))
199    }
200
201    fn poll(
202        self: Pin<&mut Self>,
203        #[allow(unused_variables)] cx: &mut Context<'_>,
204    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
205        #[cfg(target_arch = "wasm32")]
206        {
207            let this = self.get_mut();
208            let ids: Vec<ListenerId> = this.listeners.keys().cloned().collect();
209
210            for id in ids {
211                let state = this.listeners.get_mut(&id).unwrap();
212
213                // ── Phase 0: handle closing listeners ─────────────────────────
214                //
215                // Sequence: AddressExpired (if previously announced) → ListenerClosed.
216                // We re-use `announced` as the "AddressExpired not yet sent" flag:
217                // after emitting AddressExpired we set it to false so the next poll
218                // emits ListenerClosed and removes the entry.
219                if state.closing {
220                    state.bind_future = None;
221                    state.accept_future = None;
222                    if state.announced {
223                        let addr = state
224                            .listener
225                            .as_ref()
226                            .and_then(|l| l.local_addr().ok())
227                            .map(socketaddr_to_multiaddr)
228                            .unwrap_or_else(|| socketaddr_to_multiaddr(state.bind_addr));
229                        state.announced = false;
230                        return Poll::Ready(TransportEvent::AddressExpired {
231                            listener_id: id,
232                            listen_addr: addr,
233                        });
234                    }
235                    // AddressExpired already sent (or was never announced).
236                    let _ = state; // end the mutable borrow before remove
237                    this.listeners.remove(&id);
238                    return Poll::Ready(TransportEvent::ListenerClosed {
239                        listener_id: id,
240                        reason: Ok(()),
241                    });
242                }
243
244                // ── Phase 1: drive the bind future ────────────────────────────
245                if let Some(ref mut bind_fut) = state.bind_future {
246                    match bind_fut.as_mut().poll(cx) {
247                        Poll::Pending => continue,
248                        Poll::Ready(Err(e)) => {
249                            state.bind_future = None;
250                            let err = if e.kind() == std::io::ErrorKind::PermissionDenied {
251                                Error::AccessDenied
252                            } else {
253                                Error::Io(e)
254                            };
255                            return Poll::Ready(TransportEvent::ListenerError {
256                                listener_id: id,
257                                error: err,
258                            });
259                        }
260                        Poll::Ready(Ok(listener)) => {
261                            let local_addr = listener
262                                .local_addr()
263                                .map(socketaddr_to_multiaddr)
264                                .unwrap_or_else(|_| socketaddr_to_multiaddr(state.bind_addr));
265                            state.listener = Some(Arc::new(listener));
266                            state.bind_future = None;
267                            state.announced = true;
268                            return Poll::Ready(TransportEvent::NewAddress {
269                                listener_id: id,
270                                listen_addr: local_addr,
271                            });
272                        }
273                    }
274                }
275
276                // ── Phase 2: accept loop ───────────────────────────────────────
277                let Some(listener_arc) = state.listener.as_ref().map(Arc::clone) else {
278                    continue;
279                };
280
281                if state.accept_future.is_none() {
282                    let listener = Arc::clone(&listener_arc);
283                    state.accept_future = Some(Box::pin(async move {
284                        use wstd::iter::AsyncIterator as _;
285                        listener
286                            .incoming()
287                            .next()
288                            .await
289                            .unwrap_or_else(|| {
290                                Err(std::io::Error::new(
291                                    std::io::ErrorKind::BrokenPipe,
292                                    "listener closed",
293                                ))
294                            })
295                    }));
296                }
297
298                if let Some(ref mut accept_fut) = state.accept_future {
299                    match accept_fut.as_mut().poll(cx) {
300                        Poll::Pending => {}
301                        Poll::Ready(Err(e)) => {
302                            state.accept_future = None;
303                            return Poll::Ready(TransportEvent::ListenerError {
304                                listener_id: id,
305                                error: Error::Io(e),
306                            });
307                        }
308                        Poll::Ready(Ok(tcp_stream)) => {
309                            state.accept_future = None;
310                            let local_addr = listener_arc
311                                .local_addr()
312                                .map(socketaddr_to_multiaddr)
313                                .unwrap_or_else(|_| socketaddr_to_multiaddr(state.bind_addr));
314                            // send_back_addr: wstd's TcpStream::peer_addr() returns a debug
315                            // string (not a SocketAddr).  For v0.1.0 we use the listen addr as a
316                            // placeholder.  Tracking issue: add proper peer-addr extraction.
317                            let send_back_addr = local_addr.clone();
318                            let wasi_stream = WasiTcpStream::new(tcp_stream);
319                            return Poll::Ready(TransportEvent::Incoming {
320                                listener_id: id,
321                                upgrade: futures::future::ready(Ok(wasi_stream)),
322                                local_addr,
323                                send_back_addr,
324                            });
325                        }
326                    }
327                }
328            }
329
330            Poll::Pending
331        }
332
333        #[cfg(not(target_arch = "wasm32"))]
334        Poll::Pending
335    }
336}