libp2p_wasi_sockets/
transport.rs1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use libp2p_core::multiaddr::Multiaddr;
8use libp2p_core::transport::{DialOpts, ListenerId, TransportError, TransportEvent};
9use libp2p_core::Transport;
10use tracing::warn;
11
12use crate::error::Error;
13use crate::multiaddr::{multiaddr_to_socketaddr, socketaddr_to_multiaddr};
14use crate::stream::WasiTcpStream;
15
/// Socket tunables for [`WasiTcpTransport`].
///
/// NOTE(review): `WasiTcpTransport` keeps this value in a field marked
/// `#[allow(dead_code)]`, and the transport code in this file never reads these
/// fields, so they do not currently influence any socket. Confirm before relying on them.
#[derive(Debug, Clone)]
pub struct Config {
    /// Presumably toggles `TCP_NODELAY` (disables Nagle's algorithm). Defaults to `true`.
    pub nodelay: bool,
    /// Optional keep-alive setting for connections. Defaults to `None` (unset).
    pub keep_alive: Option<Duration>,
    /// Backlog requested for listening sockets. Defaults to `128`.
    pub listen_backlog: u32,
}

impl Default for Config {
    /// Defaults: `nodelay = true`, no keep-alive, `listen_backlog = 128`.
    fn default() -> Self {
        Self {
            listen_backlog: 128,
            keep_alive: None,
            nodelay: true,
        }
    }
}
36
/// Heap-allocated, pinned future used for every socket operation on wasm32.
///
/// The trait object carries no `Send` bound.
// NOTE(review): presumably because the wstd socket futures are not `Send` — confirm.
#[cfg(target_arch = "wasm32")]
type WasmBoxFut<T> = Pin<Box<dyn std::future::Future<Output = T> + 'static>>;
42
/// Bookkeeping for one listener, advanced by `WasiTcpTransport::poll`.
///
/// `listen_on` creates the entry with a pending `bind_future`. Once the bind succeeds, the
/// bound socket moves into `listener` and connections are accepted through
/// `accept_future`. `remove_listener` only raises `closing`; teardown happens while polling.
#[cfg(target_arch = "wasm32")]
struct ListenerState {
    /// Address requested in `listen_on`; reported instead of the socket's own local
    /// address whenever that lookup fails.
    bind_addr: std::net::SocketAddr,
    /// Bind started by `listen_on`; set back to `None` once it has resolved.
    bind_future: Option<WasmBoxFut<std::io::Result<wstd::net::TcpListener>>>,
    /// The bound socket, shared via `Arc` with the in-flight accept future.
    listener: Option<Arc<wstd::net::TcpListener>>,
    /// Accept of the next inbound connection; cleared after each result and recreated on
    /// a later poll.
    accept_future: Option<WasmBoxFut<std::io::Result<wstd::net::TcpStream>>>,
    /// `true` after a `NewAddress` event until the matching `AddressExpired` is emitted.
    announced: bool,
    /// Set by `remove_listener`; polling then emits `AddressExpired` (if `announced`),
    /// followed by `ListenerClosed`, and drops this entry.
    closing: bool,
}
61
62pub struct WasiTcpTransport {
70 #[allow(dead_code)] config: Config,
72 #[cfg(target_arch = "wasm32")]
73 listeners: HashMap<ListenerId, ListenerState>,
74 #[cfg(not(target_arch = "wasm32"))]
75 _phantom: std::marker::PhantomData<()>,
76}
77
// SAFETY: on wasm32 the transport owns boxed futures whose trait objects have no `Send`
// bound (see `WasmBoxFut`), so the compiler cannot derive these auto traits itself.
// They are only sound if the transport is never accessed from more than one thread.
// NOTE(review): this assumes the wasm32 target runs single-threaded (e.g. a WASI
// component without shared-memory threads) — TODO confirm; a threaded wasm32 build
// would make both impls unsound.
#[cfg(target_arch = "wasm32")]
unsafe impl Sync for WasiTcpTransport {}
// SAFETY: same single-threaded assumption as the `Sync` impl above.
#[cfg(target_arch = "wasm32")]
unsafe impl Send for WasiTcpTransport {}
84
85impl WasiTcpTransport {
86 pub fn new(config: Config) -> Self {
88 Self {
89 config,
90 #[cfg(target_arch = "wasm32")]
91 listeners: HashMap::new(),
92 #[cfg(not(target_arch = "wasm32"))]
93 _phantom: std::marker::PhantomData,
94 }
95 }
96}
97
98impl Default for WasiTcpTransport {
99 fn default() -> Self {
100 Self::new(Config::default())
101 }
102}
103
104impl Transport for WasiTcpTransport {
105 type Output = WasiTcpStream;
106 type Error = Error;
107 type ListenerUpgrade = futures::future::Ready<Result<Self::Output, Self::Error>>;
110 #[cfg(target_arch = "wasm32")]
112 type Dial = WasmBoxFut<Result<Self::Output, Self::Error>>;
113 #[cfg(not(target_arch = "wasm32"))]
114 type Dial = futures::future::Pending<Result<Self::Output, Self::Error>>;
115
116 fn listen_on(
117 &mut self,
118 id: ListenerId,
119 addr: Multiaddr,
120 ) -> Result<(), TransportError<Self::Error>> {
121 let sock_addr = multiaddr_to_socketaddr(&addr).map_err(TransportError::Other)?;
122
123 #[cfg(target_arch = "wasm32")]
124 {
125 let addr_str = sock_addr.to_string();
126 let bind_fut: WasmBoxFut<std::io::Result<wstd::net::TcpListener>> =
127 Box::pin(async move { wstd::net::TcpListener::bind(&addr_str).await });
128
129 self.listeners.insert(
130 id,
131 ListenerState {
132 bind_addr: sock_addr,
133 listener: None,
134 bind_future: Some(bind_fut),
135 accept_future: None,
136 announced: false,
137 closing: false,
138 },
139 );
140 }
141
142 #[cfg(not(target_arch = "wasm32"))]
143 {
144 let _ = (id, sock_addr);
145 }
146
147 Ok(())
148 }
149
150 fn remove_listener(&mut self, id: ListenerId) -> bool {
151 #[cfg(target_arch = "wasm32")]
152 {
153 if let Some(state) = self.listeners.get_mut(&id) {
154 state.closing = true;
155 true
156 } else {
157 false
158 }
159 }
160 #[cfg(not(target_arch = "wasm32"))]
161 {
162 let _ = id;
163 false
164 }
165 }
166
167 fn dial(
168 &mut self,
169 addr: Multiaddr,
170 _opts: DialOpts,
171 ) -> Result<Self::Dial, TransportError<Self::Error>> {
172 let sock_addr = multiaddr_to_socketaddr(&addr).map_err(TransportError::Other)?;
173 let _ = &sock_addr; #[cfg(target_arch = "wasm32")]
176 {
177 let dial_fut: WasmBoxFut<Result<WasiTcpStream, Error>> =
178 Box::pin(async move {
179 wstd::net::TcpStream::connect(sock_addr)
180 .await
181 .map(WasiTcpStream::new)
182 .map_err(|e| {
183 if e.kind() == std::io::ErrorKind::PermissionDenied {
184 warn!(
185 "Network capability denied — pass `-S inherit-network` \
186 to wasmtime to grant the component network access."
187 );
188 Error::AccessDenied
189 } else {
190 Error::Io(e)
191 }
192 })
193 });
194 return Ok(dial_fut);
195 }
196
197 #[cfg(not(target_arch = "wasm32"))]
198 Err(TransportError::Other(Error::UnsupportedMultiaddr(addr)))
199 }
200
201 fn poll(
202 self: Pin<&mut Self>,
203 #[allow(unused_variables)] cx: &mut Context<'_>,
204 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
205 #[cfg(target_arch = "wasm32")]
206 {
207 let this = self.get_mut();
208 let ids: Vec<ListenerId> = this.listeners.keys().cloned().collect();
209
210 for id in ids {
211 let state = this.listeners.get_mut(&id).unwrap();
212
213 if state.closing {
220 state.bind_future = None;
221 state.accept_future = None;
222 if state.announced {
223 let addr = state
224 .listener
225 .as_ref()
226 .and_then(|l| l.local_addr().ok())
227 .map(socketaddr_to_multiaddr)
228 .unwrap_or_else(|| socketaddr_to_multiaddr(state.bind_addr));
229 state.announced = false;
230 return Poll::Ready(TransportEvent::AddressExpired {
231 listener_id: id,
232 listen_addr: addr,
233 });
234 }
235 let _ = state; this.listeners.remove(&id);
238 return Poll::Ready(TransportEvent::ListenerClosed {
239 listener_id: id,
240 reason: Ok(()),
241 });
242 }
243
244 if let Some(ref mut bind_fut) = state.bind_future {
246 match bind_fut.as_mut().poll(cx) {
247 Poll::Pending => continue,
248 Poll::Ready(Err(e)) => {
249 state.bind_future = None;
250 let err = if e.kind() == std::io::ErrorKind::PermissionDenied {
251 Error::AccessDenied
252 } else {
253 Error::Io(e)
254 };
255 return Poll::Ready(TransportEvent::ListenerError {
256 listener_id: id,
257 error: err,
258 });
259 }
260 Poll::Ready(Ok(listener)) => {
261 let local_addr = listener
262 .local_addr()
263 .map(socketaddr_to_multiaddr)
264 .unwrap_or_else(|_| socketaddr_to_multiaddr(state.bind_addr));
265 state.listener = Some(Arc::new(listener));
266 state.bind_future = None;
267 state.announced = true;
268 return Poll::Ready(TransportEvent::NewAddress {
269 listener_id: id,
270 listen_addr: local_addr,
271 });
272 }
273 }
274 }
275
276 let Some(listener_arc) = state.listener.as_ref().map(Arc::clone) else {
278 continue;
279 };
280
281 if state.accept_future.is_none() {
282 let listener = Arc::clone(&listener_arc);
283 state.accept_future = Some(Box::pin(async move {
284 use wstd::iter::AsyncIterator as _;
285 listener
286 .incoming()
287 .next()
288 .await
289 .unwrap_or_else(|| {
290 Err(std::io::Error::new(
291 std::io::ErrorKind::BrokenPipe,
292 "listener closed",
293 ))
294 })
295 }));
296 }
297
298 if let Some(ref mut accept_fut) = state.accept_future {
299 match accept_fut.as_mut().poll(cx) {
300 Poll::Pending => {}
301 Poll::Ready(Err(e)) => {
302 state.accept_future = None;
303 return Poll::Ready(TransportEvent::ListenerError {
304 listener_id: id,
305 error: Error::Io(e),
306 });
307 }
308 Poll::Ready(Ok(tcp_stream)) => {
309 state.accept_future = None;
310 let local_addr = listener_arc
311 .local_addr()
312 .map(socketaddr_to_multiaddr)
313 .unwrap_or_else(|_| socketaddr_to_multiaddr(state.bind_addr));
314 let send_back_addr = local_addr.clone();
318 let wasi_stream = WasiTcpStream::new(tcp_stream);
319 return Poll::Ready(TransportEvent::Incoming {
320 listener_id: id,
321 upgrade: futures::future::ready(Ok(wasi_stream)),
322 local_addr,
323 send_back_addr,
324 });
325 }
326 }
327 }
328 }
329
330 Poll::Pending
331 }
332
333 #[cfg(not(target_arch = "wasm32"))]
334 Poll::Pending
335 }
336}