madsim/sim/net/
mod.rs

1//! Asynchronous network endpoint and a controlled network simulator.
2//!
3//! # Examples
4//!
5//! ```
6//! use madsim::{runtime::Runtime, net::Endpoint};
7//! use std::sync::Arc;
8//! use std::net::SocketAddr;
9//!
10//! let runtime = Runtime::new();
11//! let addr1 = "10.0.0.1:1".parse::<SocketAddr>().unwrap();
12//! let addr2 = "10.0.0.2:1".parse::<SocketAddr>().unwrap();
13//! let node1 = runtime.create_node().ip(addr1.ip()).build();
14//! let node2 = runtime.create_node().ip(addr2.ip()).build();
15//! let barrier = Arc::new(tokio::sync::Barrier::new(2));
16//! let barrier_ = barrier.clone();
17//!
18//! node1.spawn(async move {
19//!     let net = Endpoint::bind(addr1).await.unwrap();
20//!     barrier_.wait().await;  // make sure addr2 has bound
21//!
22//!     net.send_to(addr2, 1, &[1]).await.unwrap();
23//! });
24//!
25//! let f = node2.spawn(async move {
26//!     let net = Endpoint::bind(addr2).await.unwrap();
27//!     barrier.wait().await;
28//!
29//!     let mut buf = vec![0; 0x10];
30//!     let (len, from) = net.recv_from(1, &mut buf).await.unwrap();
31//!     assert_eq!(from, addr1);
32//!     assert_eq!(&buf[..len], &[1]);
33//! });
34//!
35//! runtime.block_on(f);
36//! ```
37
38use bytes::Bytes;
39use futures_util::{stream::BoxStream, StreamExt};
40use spin::Mutex;
41use std::{
42    any::Any,
43    collections::HashMap,
44    io,
45    net::{IpAddr, SocketAddr},
46    sync::Arc,
47    time::Instant,
48};
49use tokio::sync::{mpsc, oneshot};
50use tracing::*;
51
52use crate::{
53    buggify::buggify_with_prob,
54    plugin,
55    rand::{GlobalRng, Rng},
56    task::{NodeId, NodeInfo, Spawner},
57    time::{sleep, sleep_until, Duration, TimeHandle},
58};
59
60mod addr;
61mod dns;
62mod endpoint;
63pub mod ipvs;
64mod network;
65#[cfg(feature = "rpc")]
66#[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
67pub mod rpc;
68pub mod tcp;
69mod udp;
70pub mod unix;
71
72pub use self::addr::{lookup_host, ToSocketAddrs};
73use self::dns::DnsServer;
74pub use self::endpoint::{Endpoint, Receiver, Sender};
75use self::ipvs::{IpVirtualServer, ServiceAddr};
76pub use self::network::{Config, Stat};
77use self::network::{Direction, IpProtocol, Network, Socket};
78pub use self::tcp::{TcpListener, TcpStream};
79pub use self::udp::UdpSocket;
80pub use self::unix::{UnixDatagram, UnixListener, UnixStream};
81
82/// Network simulator.
83#[cfg_attr(docsrs, doc(cfg(madsim)))]
84pub struct NetSim {
85    network: Mutex<Network>,
86    dns: Mutex<DnsServer>,
87    ipvs: IpVirtualServer,
88    rand: GlobalRng,
89    time: TimeHandle,
90    hooks_req: Mutex<HashMap<NodeId, MsgHookFn>>,
91    hooks_rsp: Mutex<HashMap<NodeId, MsgHookFn>>,
92}
93
/// Type-erased message carried to a network socket.
///
/// Receivers recover the concrete value with `downcast_ref`.
pub type Payload = Box<dyn Any + Send + Sync>;

/// Message filter: returning `false` drops the inspected payload.
type MsgHookFn = Arc<dyn Fn(&Payload) -> bool + Send + Sync>;
98
99impl plugin::Simulator for NetSim {
100    fn new(_rand: &GlobalRng, _time: &TimeHandle, _config: &crate::Config) -> Self {
101        unreachable!()
102    }
103
104    fn new1(rand: &GlobalRng, time: &TimeHandle, _task: &Spawner, config: &crate::Config) -> Self {
105        NetSim {
106            network: Mutex::new(Network::new(rand.clone(), config.net.clone())),
107            dns: Mutex::new(DnsServer::default()),
108            ipvs: IpVirtualServer::default(),
109            rand: rand.clone(),
110            time: time.clone(),
111            hooks_req: Default::default(),
112            hooks_rsp: Default::default(),
113        }
114    }
115
116    fn create_node(&self, id: NodeId) {
117        let mut network = self.network.lock();
118        network.insert_node(id);
119    }
120
121    fn reset_node(&self, id: NodeId) {
122        self.reset_node(id);
123    }
124}
125
126impl NetSim {
127    /// Get [`NetSim`] of the current simulator.
128    pub fn current() -> Arc<Self> {
129        plugin::simulator()
130    }
131
132    /// Get the statistics.
133    pub fn stat(&self) -> Stat {
134        self.network.lock().stat().clone()
135    }
136
137    /// Update network configurations.
138    pub fn update_config(&self, f: impl FnOnce(&mut Config)) {
139        let mut network = self.network.lock();
140        network.update_config(f);
141    }
142
143    /// Reset a node.
144    ///
145    /// All connections will be closed.
146    pub fn reset_node(&self, id: NodeId) {
147        let mut network = self.network.lock();
148        network.reset_node(id);
149    }
150
151    /// Set IP address of a node.
152    pub fn set_ip(&self, node: NodeId, ip: IpAddr) {
153        let mut network = self.network.lock();
154        network.set_ip(node, ip);
155    }
156
157    /// Connect a node to the network.
158    #[deprecated(since = "0.3.0", note = "use `unclog_node` instead")]
159    pub fn connect(&self, id: NodeId) {
160        self.unclog_node(id);
161    }
162
163    /// Unclog the node.
164    pub fn unclog_node(&self, id: NodeId) {
165        self.network.lock().unclog_node(id, Direction::Both);
166    }
167
168    /// Unclog the node for receive.
169    pub fn unclog_node_in(&self, id: NodeId) {
170        self.network.lock().unclog_node(id, Direction::In);
171    }
172
173    /// Unclog the node for send.
174    pub fn unclog_node_out(&self, id: NodeId) {
175        self.network.lock().unclog_node(id, Direction::Out);
176    }
177
178    /// Disconnect a node from the network.
179    #[deprecated(since = "0.3.0", note = "use `clog_node` instead")]
180    pub fn disconnect(&self, id: NodeId) {
181        self.clog_node(id);
182    }
183
184    /// Clog the node.
185    pub fn clog_node(&self, id: NodeId) {
186        self.network.lock().clog_node(id, Direction::Both);
187    }
188
189    /// Clog the node for receive.
190    pub fn clog_node_in(&self, id: NodeId) {
191        self.network.lock().clog_node(id, Direction::In);
192    }
193
194    /// Clog the node for send.
195    pub fn clog_node_out(&self, id: NodeId) {
196        self.network.lock().clog_node(id, Direction::Out);
197    }
198
199    /// Connect a pair of nodes.
200    #[deprecated(since = "0.3.0", note = "call `unclog_link` twice instead")]
201    pub fn connect2(&self, node1: NodeId, node2: NodeId) {
202        let mut network = self.network.lock();
203        network.unclog_link(node1, node2);
204        network.unclog_link(node2, node1);
205    }
206
207    /// Unclog the link from `src` to `dst`.
208    pub fn unclog_link(&self, src: NodeId, dst: NodeId) {
209        self.network.lock().unclog_link(src, dst);
210    }
211
212    /// Disconnect a pair of nodes.
213    #[deprecated(since = "0.3.0", note = "call `clog_link` twice instead")]
214    pub fn disconnect2(&self, node1: NodeId, node2: NodeId) {
215        let mut network = self.network.lock();
216        network.clog_link(node1, node2);
217        network.clog_link(node2, node1);
218    }
219
220    /// Clog the link from `src` to `dst`.
221    pub fn clog_link(&self, src: NodeId, dst: NodeId) {
222        self.network.lock().clog_link(src, dst);
223    }
224
225    /// Add a DNS record for the cluster.
226    pub fn add_dns_record(&self, hostname: &str, ip: IpAddr) {
227        self.dns.lock().add(hostname, ip);
228    }
229
230    /// Performs a DNS lookup.
231    pub(crate) fn lookup_host(&self, hostname: &str) -> Option<IpAddr> {
232        self.dns.lock().lookup(hostname)
233    }
234
235    /// Get the IPVS for all nodes.
236    pub fn global_ipvs(&self) -> &IpVirtualServer {
237        &self.ipvs
238    }
239
240    /// Add a hook function for RPC requests.
241    ///
242    /// If the hook function returns `false`, the request will be dropped.
243    #[cfg(feature = "rpc")]
244    #[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
245    pub fn hook_rpc_req<R: 'static>(
246        &self,
247        node: NodeId,
248        f: impl Fn(&R) -> bool + Send + Sync + 'static,
249    ) {
250        self.hooks_req.lock().insert(
251            node,
252            Arc::new(move |payload| {
253                if let Some((_, payload)) = payload.downcast_ref::<(u64, Payload)>() {
254                    if let Some((_, msg, _)) = payload.downcast_ref::<(u64, R, Bytes)>() {
255                        return f(msg);
256                    }
257                }
258                true
259            }),
260        );
261    }
262
263    /// Add a hook function for RPC responses.
264    ///
265    /// If the hook function returns `false`, the response will be dropped.
266    #[cfg(feature = "rpc")]
267    #[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
268    pub fn hook_rpc_rsp<R: 'static>(
269        &self,
270        node: NodeId,
271        f: impl Fn(&R) -> bool + Send + Sync + 'static,
272    ) {
273        self.hooks_rsp.lock().insert(
274            node,
275            Arc::new(move |payload| {
276                if let Some((_, payload)) = payload.downcast_ref::<(u64, Payload)>() {
277                    if let Some((msg, _)) = payload.downcast_ref::<(R, Bytes)>() {
278                        return f(msg);
279                    }
280                }
281                true
282            }),
283        );
284    }
285
286    /// Delay a small random time and probably inject failure.
287    async fn rand_delay(&self) -> io::Result<()> {
288        let mut delay = Duration::from_micros(self.rand.with(|rng| rng.gen_range(0..5)));
289        if buggify_with_prob(0.1) {
290            delay = Duration::from_secs(self.rand.with(|rng| rng.gen_range(1..5)));
291        }
292        self.time.sleep(delay).await;
293        // TODO: inject failure
294        Ok(())
295    }
296
297    /// Send a message to the destination.
298    pub(crate) async fn send(
299        &self,
300        node: NodeId,
301        port: u16,
302        mut dst: SocketAddr,
303        protocol: IpProtocol,
304        msg: Payload,
305    ) -> io::Result<()> {
306        self.rand_delay().await?;
307        if let Some(hook) = self.hooks_req.lock().get(&node).cloned() {
308            if !hook(&msg) {
309                return Ok(());
310            }
311        }
312        if let Some(addr) = self
313            .ipvs
314            .get_server(ServiceAddr::from_addr_proto(dst, protocol))
315        {
316            dst = addr.parse().expect("invalid socket address");
317        }
318        if let Some((ip, dst_node, socket, latency)) =
319            self.network.lock().try_send(node, dst, protocol)
320        {
321            trace!(?latency, "delay");
322            let hook = self.hooks_rsp.lock().get(&dst_node).cloned();
323            self.time.add_timer(latency, move || {
324                if let Some(hook) = hook {
325                    if !hook(&msg) {
326                        return;
327                    }
328                }
329                socket.deliver((ip, port).into(), dst, msg);
330            });
331        }
332        Ok(())
333    }
334
335    /// Opens a new connection to destination.
336    // TODO: rename
337    pub(crate) async fn connect1(
338        self: &Arc<Self>,
339        node: NodeId,
340        port: u16,
341        mut dst: SocketAddr,
342        protocol: IpProtocol,
343    ) -> io::Result<(PayloadSender, PayloadReceiver, SocketAddr)> {
344        self.rand_delay().await?;
345        if let Some(addr) = self
346            .ipvs
347            .get_server(ServiceAddr::from_addr_proto(dst, protocol))
348        {
349            dst = addr.parse().expect("invalid socket address");
350        }
351        let (ip, dst_node, socket, latency) = (self.network.lock().try_send(node, dst, protocol))
352            .ok_or_else(|| {
353            io::Error::new(io::ErrorKind::ConnectionRefused, "connection refused")
354        })?;
355        let src = (ip, port).into();
356        let (tx1, rx1) = self.channel(node, dst, protocol);
357        let (tx2, rx2) = self.channel(dst_node, src, protocol);
358        trace!(?latency, "delay");
359        // FIXME: delay
360        // self.time.add_timer(latency, move || {
361        socket.new_connection(src, dst, tx2, rx1);
362        // });
363        Ok((tx1, rx2, src))
364    }
365
366    /// Create a reliable, ordered channel between two endpoints.
367    fn channel(
368        self: &Arc<Self>,
369        node: NodeId,
370        dst: SocketAddr,
371        protocol: IpProtocol,
372    ) -> (PayloadSender, PayloadReceiver) {
373        let (tx, mut rx) = mpsc::unbounded_channel();
374        let net = self.clone();
375        let test_link = Arc::new(move || {
376            net.network
377                .lock()
378                .try_send(node, dst, protocol)
379                .map(|(_, _, _, latency)| net.time.now_instant() + latency)
380        });
381        let sender = PayloadSender {
382            test_link: test_link.clone(),
383            tx,
384        };
385        let recver = async_stream::stream! {
386            while let Some((value, mut state)) = rx.recv().await {
387                // wait until the link is ready
388                let mut backoff = Duration::from_millis(1);
389                let arrive_time = loop {
390                    if let Some(arrive_time) = state {
391                        break arrive_time;
392                    }
393                    // backoff
394                    sleep(backoff).await;
395                    backoff = (backoff * 2).min(Duration::from_secs(10));
396                    // retry
397                    state = test_link();
398                };
399                sleep_until(arrive_time).await;
400                yield value;
401            }
402        }
403        .boxed();
404        (sender, recver)
405    }
406}
407
408#[doc(hidden)]
409pub struct PayloadSender {
410    test_link: Arc<dyn Fn() -> State + Send + Sync>,
411    tx: mpsc::UnboundedSender<(Payload, State)>,
412}
413
/// Link state captured when a packet is sent: the instant it should arrive,
/// or `None` if the link was unusable at that moment.
type State = Option<Instant>;
416
417impl PayloadSender {
418    fn send(&self, value: Payload) -> Option<()> {
419        let state = (self.test_link)();
420        self.tx.send((value, state)).ok()
421    }
422
423    fn is_closed(&self) -> bool {
424        self.tx.is_closed()
425    }
426
427    async fn closed(&self) {
428        self.tx.closed().await;
429    }
430}
431
432#[doc(hidden)]
433pub type PayloadReceiver = BoxStream<'static, Payload>;
434
435/// An RAII structure used to release the bound port.
436pub(crate) struct BindGuard {
437    net: Arc<NetSim>,
438    node: Arc<NodeInfo>,
439    /// Bound address.
440    addr: SocketAddr,
441    protocol: IpProtocol,
442}
443
444impl BindGuard {
445    /// Bind a socket to the address.
446    pub async fn bind(
447        addr: impl ToSocketAddrs,
448        protocol: IpProtocol,
449        socket: Arc<dyn Socket>,
450    ) -> io::Result<Self> {
451        let net = plugin::simulator::<NetSim>();
452        let node = crate::context::current_task().node.clone();
453
454        // attempt to bind to each address
455        let mut last_err = None;
456        for addr in lookup_host(addr).await? {
457            net.rand_delay().await?;
458            match net
459                .network
460                .lock()
461                .bind(node.id, addr, protocol, socket.clone())
462            {
463                Ok(addr) => {
464                    return Ok(BindGuard {
465                        net: net.clone(),
466                        node,
467                        addr,
468                        protocol,
469                    })
470                }
471                Err(e) => last_err = Some(e),
472            }
473        }
474        Err(last_err.unwrap_or_else(|| {
475            io::Error::new(
476                io::ErrorKind::InvalidInput,
477                "could not resolve to any addresses",
478            )
479        }))
480    }
481}
482
483impl Drop for BindGuard {
484    fn drop(&mut self) {
485        // avoid interfering with restarted node
486        if self.node.is_killed() {
487            return;
488        }
489        // avoid panic on panicking
490        if let Some(mut network) = self.net.network.try_lock() {
491            network.close(self.node.id, self.addr, self.protocol);
492        }
493    }
494}