async_wasi/snapshots/common/net/
mod.rs

1#[cfg(all(unix, feature = "async_tokio"))]
2pub mod async_tokio;
3
4pub use super::vfs::*;
5
6use super::{
7    error::Errno,
8    types::{self as wasi_types, __wasi_subscription_t},
9};
10use std::{
11    future::Future,
12    io::{self, Read, Write},
13    net,
14    time::{Duration, SystemTime},
15};
16use wasi_types::{
17    __wasi_clockid_t::{
18        __WASI_CLOCKID_MONOTONIC as CLOCKID_MONOTONIC, __WASI_CLOCKID_REALTIME as CLOCKID_REALTIME,
19    },
20    __wasi_eventtype_t::{
21        __WASI_EVENTTYPE_CLOCK as CLOCK, __WASI_EVENTTYPE_FD_READ as RD,
22        __WASI_EVENTTYPE_FD_WRITE as WR,
23    },
24};
25
/// Address family half of a socket's `(AddressFamily, SocketType)` pair.
///
/// `Default` yields [`AddressFamily::Inet4`].
#[derive(Clone, Copy, Debug, Default)]
pub enum AddressFamily {
    /// Default family (IPv4, going by the name).
    #[default]
    Inet4,
    /// Alternative family (IPv6, going by the name).
    Inet6,
}
32
/// Socket-type half of a socket's `(AddressFamily, SocketType)` pair.
///
/// `Default` yields [`SocketType::Stream`].
#[derive(Clone, Copy, Debug, Default)]
pub enum SocketType {
    /// Datagram socket.
    Datagram,
    /// Stream socket; the default.
    #[default]
    Stream,
}
39
/// Connection phase recorded for a socket.
///
/// `Default` yields [`ConnectState::Empty`]. The default is derived with
/// `#[default]`, matching how the other enums in this module declare theirs,
/// instead of a hand-written `impl Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ConnectState {
    /// Initial state; the default.
    #[default]
    Empty,
    /// Listening state.
    Listening,
    /// Connected state.
    Connected,
    /// Connecting state (a connection attempt is presumably still in flight).
    Connecting,
}
53
54#[derive(Debug, Clone, Default)]
55pub struct WasiSocketState {
56    pub sock_type: (AddressFamily, SocketType),
57    pub local_addr: Option<net::SocketAddr>,
58    pub peer_addr: Option<net::SocketAddr>,
59    pub bind_device: Vec<u8>,
60    pub backlog: u32,
61    pub shutdown: Option<net::Shutdown>,
62    pub nonblocking: bool,
63    pub so_reuseaddr: bool,
64    pub so_conn_state: ConnectState,
65    pub so_recv_buf_size: usize,
66    pub so_send_buf_size: usize,
67    pub so_recv_timeout: Option<Duration>,
68    pub so_send_timeout: Option<Duration>,
69    pub fs_rights: WASIRights,
70}
71
72#[derive(Debug, Clone, Copy)]
73pub enum SubscriptionFdType {
74    Read(wasi_types::__wasi_userdata_t),
75    Write(wasi_types::__wasi_userdata_t),
76    Both {
77        read: wasi_types::__wasi_userdata_t,
78        write: wasi_types::__wasi_userdata_t,
79    },
80}
81
82#[derive(Debug, Clone, Copy)]
83pub struct SubscriptionFd {
84    pub fd: wasi_types::__wasi_fd_t,
85    pub type_: SubscriptionFdType,
86}
87
88impl SubscriptionFd {
89    pub fn set_write(&mut self, userdata: wasi_types::__wasi_userdata_t) {
90        let read_userdata = match &mut self.type_ {
91            SubscriptionFdType::Read(v) => *v,
92            SubscriptionFdType::Write(v) => {
93                *v = userdata;
94                return;
95            }
96            SubscriptionFdType::Both { read, write } => {
97                *write = userdata;
98                return;
99            }
100        };
101        self.type_ = SubscriptionFdType::Both {
102            read: read_userdata,
103            write: userdata,
104        };
105    }
106
107    pub fn set_read(&mut self, userdata: wasi_types::__wasi_userdata_t) {
108        let write_userdata = match &mut self.type_ {
109            SubscriptionFdType::Write(v) => *v,
110            SubscriptionFdType::Read(v) => {
111                *v = userdata;
112                return;
113            }
114            SubscriptionFdType::Both { read, write } => {
115                *read = userdata;
116                return;
117            }
118        };
119        self.type_ = SubscriptionFdType::Both {
120            read: userdata,
121            write: write_userdata,
122        };
123    }
124}
125
126#[derive(Debug, Clone, Copy)]
127pub struct SubscriptionClock {
128    pub timeout: Option<SystemTime>,
129    pub userdata: wasi_types::__wasi_userdata_t,
130    pub err: Option<Errno>,
131}
132
133#[derive(Debug, Clone, Copy)]
134pub enum Subscription {
135    FD(SubscriptionFd),
136    RealClock(SubscriptionClock),
137}
138
139impl Subscription {
140    pub fn from(s: &__wasi_subscription_t) -> Result<Subscription, Errno> {
141        let userdata = s.userdata;
142        match s.u.tag {
143            CLOCK => {
144                let clock = unsafe { s.u.u.clock };
145                match clock.id {
146                    CLOCKID_REALTIME | CLOCKID_MONOTONIC => {
147                        if clock.flags == 1 {
148                            if let Some(ddl) = std::time::UNIX_EPOCH
149                                .checked_add(Duration::from_nanos(clock.timeout + clock.precision))
150                            {
151                                Ok(Subscription::RealClock(SubscriptionClock {
152                                    timeout: Some(ddl),
153                                    userdata,
154                                    err: None,
155                                }))
156                            } else {
157                                Ok(Subscription::RealClock(SubscriptionClock {
158                                    timeout: None,
159                                    userdata,
160                                    err: Some(Errno::__WASI_ERRNO_INVAL),
161                                }))
162                            }
163                        } else if clock.timeout == 0 {
164                            Ok(Subscription::RealClock(SubscriptionClock {
165                                timeout: None,
166                                userdata,
167                                err: None,
168                            }))
169                        } else {
170                            let duration = Duration::from_nanos(clock.timeout + clock.precision);
171
172                            let timeout = std::time::SystemTime::now().checked_add(duration);
173
174                            Ok(Subscription::RealClock(SubscriptionClock {
175                                timeout,
176                                userdata,
177                                err: None,
178                            }))
179                        }
180                    }
181
182                    _ => Ok(Subscription::RealClock(SubscriptionClock {
183                        timeout: None,
184                        userdata,
185                        err: Some(Errno::__WASI_ERRNO_NODEV),
186                    })),
187                }
188            }
189            RD => {
190                let fd_read = unsafe { s.u.u.fd_read };
191                Ok(Subscription::FD(SubscriptionFd {
192                    fd: fd_read.file_descriptor,
193                    type_: SubscriptionFdType::Read(userdata),
194                }))
195            }
196            WR => {
197                let fd_read = unsafe { s.u.u.fd_read };
198                Ok(Subscription::FD(SubscriptionFd {
199                    fd: fd_read.file_descriptor,
200                    type_: SubscriptionFdType::Write(userdata),
201                }))
202            }
203            _ => Err(Errno::__WASI_ERRNO_INVAL),
204        }
205    }
206}
207
208#[derive(Debug)]
209pub enum PrePoll {
210    OnlyFd(Vec<SubscriptionFd>),
211    OnlyClock(SubscriptionClock),
212    ClockAndFd(SubscriptionClock, Vec<SubscriptionFd>),
213}
214
215impl PrePoll {
216    pub fn from_wasi_subscription(
217        subs: &[wasi_types::__wasi_subscription_t],
218    ) -> Result<Self, Errno> {
219        use std::collections::HashMap;
220        let mut fds = HashMap::with_capacity(subs.len());
221
222        let mut timeout: Option<SubscriptionClock> = None;
223        for s in subs {
224            let s = Subscription::from(s)?;
225            match s {
226                Subscription::FD(fd) => {
227                    let type_ = fd.type_;
228
229                    fds.entry(fd.fd)
230                        .and_modify(|e: &mut SubscriptionFd| match type_ {
231                            SubscriptionFdType::Read(data) => e.set_read(data),
232                            SubscriptionFdType::Write(data) => e.set_write(data),
233                            SubscriptionFdType::Both { read, write } => {
234                                e.type_ = SubscriptionFdType::Both { read, write };
235                            }
236                        })
237                        .or_insert(fd);
238                }
239                Subscription::RealClock(clock) => {
240                    if clock.err.is_some() {
241                        return Ok(PrePoll::OnlyClock(clock));
242                    }
243                    if clock.timeout.is_none() {
244                        return Ok(PrePoll::OnlyClock(clock));
245                    }
246
247                    if let Some(old_clock) = &mut timeout {
248                        let new_timeout = clock.timeout.unwrap();
249                        let old_timeout = old_clock.timeout.unwrap();
250
251                        if new_timeout < old_timeout {
252                            *old_clock = clock
253                        }
254                    } else {
255                        timeout = Some(clock)
256                    }
257                }
258            }
259        }
260
261        let fd_vec: Vec<SubscriptionFd> = fds.into_values().collect();
262
263        if let Some(clock) = timeout {
264            if fd_vec.is_empty() {
265                Ok(PrePoll::OnlyClock(clock))
266            } else {
267                Ok(PrePoll::ClockAndFd(clock, fd_vec))
268            }
269        } else {
270            Ok(PrePoll::OnlyFd(fd_vec))
271        }
272    }
273}