1use bytes::Bytes;
39use futures_util::{stream::BoxStream, StreamExt};
40use spin::Mutex;
41use std::{
42 any::Any,
43 collections::HashMap,
44 io,
45 net::{IpAddr, SocketAddr},
46 sync::Arc,
47 time::Instant,
48};
49use tokio::sync::{mpsc, oneshot};
50use tracing::*;
51
52use crate::{
53 buggify::buggify_with_prob,
54 plugin,
55 rand::{GlobalRng, Rng},
56 task::{NodeId, NodeInfo, Spawner},
57 time::{sleep, sleep_until, Duration, TimeHandle},
58};
59
60mod addr;
61mod dns;
62mod endpoint;
63pub mod ipvs;
64mod network;
65#[cfg(feature = "rpc")]
66#[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
67pub mod rpc;
68pub mod tcp;
69mod udp;
70pub mod unix;
71
72pub use self::addr::{lookup_host, ToSocketAddrs};
73use self::dns::DnsServer;
74pub use self::endpoint::{Endpoint, Receiver, Sender};
75use self::ipvs::{IpVirtualServer, ServiceAddr};
76pub use self::network::{Config, Stat};
77use self::network::{Direction, IpProtocol, Network, Socket};
78pub use self::tcp::{TcpListener, TcpStream};
79pub use self::udp::UdpSocket;
80pub use self::unix::{UnixDatagram, UnixListener, UnixStream};
81
#[cfg_attr(docsrs, doc(cfg(madsim)))]
/// Network simulator.
///
/// Bundles the simulated network state, DNS records, virtual-server table and the
/// per-node message hooks used by the methods in this module.
pub struct NetSim {
    /// Simulated network state; all clog/unclog, bind and send operations lock it.
    network: Mutex<Network>,
    /// DNS records written by `add_dns_record` and read by `lookup_host`.
    dns: Mutex<DnsServer>,
    /// Virtual-server table consulted by `send` and `connect1` to redirect a destination.
    ipvs: IpVirtualServer,
    /// Random source used by `rand_delay`.
    rand: GlobalRng,
    /// Time handle used for sleeping and for scheduling delivery timers.
    time: TimeHandle,
    /// Hooks looked up by the *sending* node in `send` before a message leaves.
    hooks_req: Mutex<HashMap<NodeId, MsgHookFn>>,
    /// Hooks looked up by the *destination* node in `send` and run right before delivery.
    hooks_rsp: Mutex<HashMap<NodeId, MsgHookFn>>,
}
93
/// Type-erased message carried through the simulated network.
///
/// Receivers recover the concrete type with `downcast_ref` (see the RPC hooks below).
pub type Payload = Box<dyn Any + Send + Sync>;
96
/// Shared predicate over a message payload.
///
/// `send` drops the message when the hook returns `false`.
type MsgHookFn = Arc<dyn Fn(&Payload) -> bool + Send + Sync>;
98
99impl plugin::Simulator for NetSim {
100 fn new(_rand: &GlobalRng, _time: &TimeHandle, _config: &crate::Config) -> Self {
101 unreachable!()
102 }
103
104 fn new1(rand: &GlobalRng, time: &TimeHandle, _task: &Spawner, config: &crate::Config) -> Self {
105 NetSim {
106 network: Mutex::new(Network::new(rand.clone(), config.net.clone())),
107 dns: Mutex::new(DnsServer::default()),
108 ipvs: IpVirtualServer::default(),
109 rand: rand.clone(),
110 time: time.clone(),
111 hooks_req: Default::default(),
112 hooks_rsp: Default::default(),
113 }
114 }
115
116 fn create_node(&self, id: NodeId) {
117 let mut network = self.network.lock();
118 network.insert_node(id);
119 }
120
121 fn reset_node(&self, id: NodeId) {
122 self.reset_node(id);
123 }
124}
125
126impl NetSim {
127 pub fn current() -> Arc<Self> {
129 plugin::simulator()
130 }
131
132 pub fn stat(&self) -> Stat {
134 self.network.lock().stat().clone()
135 }
136
137 pub fn update_config(&self, f: impl FnOnce(&mut Config)) {
139 let mut network = self.network.lock();
140 network.update_config(f);
141 }
142
143 pub fn reset_node(&self, id: NodeId) {
147 let mut network = self.network.lock();
148 network.reset_node(id);
149 }
150
151 pub fn set_ip(&self, node: NodeId, ip: IpAddr) {
153 let mut network = self.network.lock();
154 network.set_ip(node, ip);
155 }
156
157 #[deprecated(since = "0.3.0", note = "use `unclog_node` instead")]
159 pub fn connect(&self, id: NodeId) {
160 self.unclog_node(id);
161 }
162
163 pub fn unclog_node(&self, id: NodeId) {
165 self.network.lock().unclog_node(id, Direction::Both);
166 }
167
168 pub fn unclog_node_in(&self, id: NodeId) {
170 self.network.lock().unclog_node(id, Direction::In);
171 }
172
173 pub fn unclog_node_out(&self, id: NodeId) {
175 self.network.lock().unclog_node(id, Direction::Out);
176 }
177
178 #[deprecated(since = "0.3.0", note = "use `clog_node` instead")]
180 pub fn disconnect(&self, id: NodeId) {
181 self.clog_node(id);
182 }
183
184 pub fn clog_node(&self, id: NodeId) {
186 self.network.lock().clog_node(id, Direction::Both);
187 }
188
189 pub fn clog_node_in(&self, id: NodeId) {
191 self.network.lock().clog_node(id, Direction::In);
192 }
193
194 pub fn clog_node_out(&self, id: NodeId) {
196 self.network.lock().clog_node(id, Direction::Out);
197 }
198
199 #[deprecated(since = "0.3.0", note = "call `unclog_link` twice instead")]
201 pub fn connect2(&self, node1: NodeId, node2: NodeId) {
202 let mut network = self.network.lock();
203 network.unclog_link(node1, node2);
204 network.unclog_link(node2, node1);
205 }
206
207 pub fn unclog_link(&self, src: NodeId, dst: NodeId) {
209 self.network.lock().unclog_link(src, dst);
210 }
211
212 #[deprecated(since = "0.3.0", note = "call `clog_link` twice instead")]
214 pub fn disconnect2(&self, node1: NodeId, node2: NodeId) {
215 let mut network = self.network.lock();
216 network.clog_link(node1, node2);
217 network.clog_link(node2, node1);
218 }
219
220 pub fn clog_link(&self, src: NodeId, dst: NodeId) {
222 self.network.lock().clog_link(src, dst);
223 }
224
225 pub fn add_dns_record(&self, hostname: &str, ip: IpAddr) {
227 self.dns.lock().add(hostname, ip);
228 }
229
230 pub(crate) fn lookup_host(&self, hostname: &str) -> Option<IpAddr> {
232 self.dns.lock().lookup(hostname)
233 }
234
235 pub fn global_ipvs(&self) -> &IpVirtualServer {
237 &self.ipvs
238 }
239
240 #[cfg(feature = "rpc")]
244 #[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
245 pub fn hook_rpc_req<R: 'static>(
246 &self,
247 node: NodeId,
248 f: impl Fn(&R) -> bool + Send + Sync + 'static,
249 ) {
250 self.hooks_req.lock().insert(
251 node,
252 Arc::new(move |payload| {
253 if let Some((_, payload)) = payload.downcast_ref::<(u64, Payload)>() {
254 if let Some((_, msg, _)) = payload.downcast_ref::<(u64, R, Bytes)>() {
255 return f(msg);
256 }
257 }
258 true
259 }),
260 );
261 }
262
263 #[cfg(feature = "rpc")]
267 #[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
268 pub fn hook_rpc_rsp<R: 'static>(
269 &self,
270 node: NodeId,
271 f: impl Fn(&R) -> bool + Send + Sync + 'static,
272 ) {
273 self.hooks_rsp.lock().insert(
274 node,
275 Arc::new(move |payload| {
276 if let Some((_, payload)) = payload.downcast_ref::<(u64, Payload)>() {
277 if let Some((msg, _)) = payload.downcast_ref::<(R, Bytes)>() {
278 return f(msg);
279 }
280 }
281 true
282 }),
283 );
284 }
285
286 async fn rand_delay(&self) -> io::Result<()> {
288 let mut delay = Duration::from_micros(self.rand.with(|rng| rng.gen_range(0..5)));
289 if buggify_with_prob(0.1) {
290 delay = Duration::from_secs(self.rand.with(|rng| rng.gen_range(1..5)));
291 }
292 self.time.sleep(delay).await;
293 Ok(())
295 }
296
297 pub(crate) async fn send(
299 &self,
300 node: NodeId,
301 port: u16,
302 mut dst: SocketAddr,
303 protocol: IpProtocol,
304 msg: Payload,
305 ) -> io::Result<()> {
306 self.rand_delay().await?;
307 if let Some(hook) = self.hooks_req.lock().get(&node).cloned() {
308 if !hook(&msg) {
309 return Ok(());
310 }
311 }
312 if let Some(addr) = self
313 .ipvs
314 .get_server(ServiceAddr::from_addr_proto(dst, protocol))
315 {
316 dst = addr.parse().expect("invalid socket address");
317 }
318 if let Some((ip, dst_node, socket, latency)) =
319 self.network.lock().try_send(node, dst, protocol)
320 {
321 trace!(?latency, "delay");
322 let hook = self.hooks_rsp.lock().get(&dst_node).cloned();
323 self.time.add_timer(latency, move || {
324 if let Some(hook) = hook {
325 if !hook(&msg) {
326 return;
327 }
328 }
329 socket.deliver((ip, port).into(), dst, msg);
330 });
331 }
332 Ok(())
333 }
334
335 pub(crate) async fn connect1(
338 self: &Arc<Self>,
339 node: NodeId,
340 port: u16,
341 mut dst: SocketAddr,
342 protocol: IpProtocol,
343 ) -> io::Result<(PayloadSender, PayloadReceiver, SocketAddr)> {
344 self.rand_delay().await?;
345 if let Some(addr) = self
346 .ipvs
347 .get_server(ServiceAddr::from_addr_proto(dst, protocol))
348 {
349 dst = addr.parse().expect("invalid socket address");
350 }
351 let (ip, dst_node, socket, latency) = (self.network.lock().try_send(node, dst, protocol))
352 .ok_or_else(|| {
353 io::Error::new(io::ErrorKind::ConnectionRefused, "connection refused")
354 })?;
355 let src = (ip, port).into();
356 let (tx1, rx1) = self.channel(node, dst, protocol);
357 let (tx2, rx2) = self.channel(dst_node, src, protocol);
358 trace!(?latency, "delay");
359 socket.new_connection(src, dst, tx2, rx1);
362 Ok((tx1, rx2, src))
364 }
365
366 fn channel(
368 self: &Arc<Self>,
369 node: NodeId,
370 dst: SocketAddr,
371 protocol: IpProtocol,
372 ) -> (PayloadSender, PayloadReceiver) {
373 let (tx, mut rx) = mpsc::unbounded_channel();
374 let net = self.clone();
375 let test_link = Arc::new(move || {
376 net.network
377 .lock()
378 .try_send(node, dst, protocol)
379 .map(|(_, _, _, latency)| net.time.now_instant() + latency)
380 });
381 let sender = PayloadSender {
382 test_link: test_link.clone(),
383 tx,
384 };
385 let recver = async_stream::stream! {
386 while let Some((value, mut state)) = rx.recv().await {
387 let mut backoff = Duration::from_millis(1);
389 let arrive_time = loop {
390 if let Some(arrive_time) = state {
391 break arrive_time;
392 }
393 sleep(backoff).await;
395 backoff = (backoff * 2).min(Duration::from_secs(10));
396 state = test_link();
398 };
399 sleep_until(arrive_time).await;
400 yield value;
401 }
402 }
403 .boxed();
404 (sender, recver)
405 }
406}
407
#[doc(hidden)]
/// Sending half of a simulated connection created by `NetSim::channel`.
pub struct PayloadSender {
    /// Probes the link; yields the arrival instant when the network currently has a route.
    test_link: Arc<dyn Fn() -> State + Send + Sync>,
    /// Channel to the receiver stream; each payload travels with its probe result.
    tx: mpsc::UnboundedSender<(Payload, State)>,
}
413
/// Result of probing a link: `Some(arrival time)` when the network returned a route,
/// `None` when it did not (the receiver then retries the probe).
type State = Option<Instant>;
416
417impl PayloadSender {
418 fn send(&self, value: Payload) -> Option<()> {
419 let state = (self.test_link)();
420 self.tx.send((value, state)).ok()
421 }
422
423 fn is_closed(&self) -> bool {
424 self.tx.is_closed()
425 }
426
427 async fn closed(&self) {
428 self.tx.closed().await;
429 }
430}
431
#[doc(hidden)]
/// Receiving half of a simulated connection: a boxed stream of payloads, as built by
/// `NetSim::channel`.
pub type PayloadReceiver = BoxStream<'static, Payload>;
434
/// Guard for a bound address; dropping it closes the address in the network
/// (see the `Drop` implementation).
pub(crate) struct BindGuard {
    /// Simulator whose network holds the binding.
    net: Arc<NetSim>,
    /// Node that owns the binding.
    node: Arc<NodeInfo>,
    /// Address returned by the network's `bind` call.
    addr: SocketAddr,
    /// Protocol the address was bound for.
    protocol: IpProtocol,
}
443
444impl BindGuard {
445 pub async fn bind(
447 addr: impl ToSocketAddrs,
448 protocol: IpProtocol,
449 socket: Arc<dyn Socket>,
450 ) -> io::Result<Self> {
451 let net = plugin::simulator::<NetSim>();
452 let node = crate::context::current_task().node.clone();
453
454 let mut last_err = None;
456 for addr in lookup_host(addr).await? {
457 net.rand_delay().await?;
458 match net
459 .network
460 .lock()
461 .bind(node.id, addr, protocol, socket.clone())
462 {
463 Ok(addr) => {
464 return Ok(BindGuard {
465 net: net.clone(),
466 node,
467 addr,
468 protocol,
469 })
470 }
471 Err(e) => last_err = Some(e),
472 }
473 }
474 Err(last_err.unwrap_or_else(|| {
475 io::Error::new(
476 io::ErrorKind::InvalidInput,
477 "could not resolve to any addresses",
478 )
479 }))
480 }
481}
482
483impl Drop for BindGuard {
484 fn drop(&mut self) {
485 if self.node.is_killed() {
487 return;
488 }
489 if let Some(mut network) = self.net.network.try_lock() {
491 network.close(self.node.id, self.addr, self.protocol);
492 }
493 }
494}