use bytes::Bytes;
use futures_util::{stream::BoxStream, StreamExt};
use spin::Mutex;
use std::{
any::Any,
collections::HashMap,
io,
net::{IpAddr, SocketAddr},
sync::Arc,
time::Instant,
};
use tokio::sync::{mpsc, oneshot};
use tracing::*;
use crate::{
buggify::buggify_with_prob,
plugin,
rand::{GlobalRng, Rng},
task::{NodeId, NodeInfo, Spawner},
time::{sleep, sleep_until, Duration, TimeHandle},
};
mod addr;
mod dns;
mod endpoint;
pub mod ipvs;
mod network;
#[cfg(feature = "rpc")]
#[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
pub mod rpc;
pub mod tcp;
mod udp;
pub mod unix;
pub use self::addr::{lookup_host, ToSocketAddrs};
use self::dns::DnsServer;
pub use self::endpoint::{Endpoint, Receiver, Sender};
use self::ipvs::{IpVirtualServer, ServiceAddr};
pub use self::network::{Config, Stat};
use self::network::{Direction, IpProtocol, Network, Socket};
pub use self::tcp::{TcpListener, TcpStream};
pub use self::udp::UdpSocket;
pub use self::unix::{UnixDatagram, UnixListener, UnixStream};
/// Simulated network: routes payloads between simulated nodes and owns the
/// simulator-wide DNS table, IP virtual server table, and per-node message hooks.
#[cfg_attr(docsrs, doc(cfg(madsim)))]
pub struct NetSim {
    // Node registry, IP assignment, clog state, socket bindings and statistics,
    // guarded by a (non-reentrant) spin lock.
    network: Mutex<Network>,
    // Hostname -> IP records used by `lookup_host`.
    dns: Mutex<DnsServer>,
    // Virtual-service table; matching destinations are rewritten on send/connect.
    ipvs: IpVirtualServer,
    // RNG handle used to draw random send/connect/bind delays.
    rand: GlobalRng,
    // Simulated clock handle (sleeps and delivery timers).
    time: TimeHandle,
    // Request hooks keyed by the *sending* node; consulted before a message is routed.
    hooks_req: Mutex<HashMap<NodeId, MsgHookFn>>,
    // Response hooks keyed by the *destination* node; run when a message is delivered.
    hooks_rsp: Mutex<HashMap<NodeId, MsgHookFn>>,
}
/// Type-erased message carried across the simulated network; receivers recover
/// the concrete type with `downcast_ref`.
pub type Payload = Box<dyn Any + Sync + Send>;
/// Message filter installed per node: returns `false` to drop the payload.
type MsgHookFn = Arc<dyn Fn(&Payload) -> bool + Sync + Send>;
impl plugin::Simulator for NetSim {
fn new(_rand: &GlobalRng, _time: &TimeHandle, _config: &crate::Config) -> Self {
unreachable!()
}
fn new1(rand: &GlobalRng, time: &TimeHandle, _task: &Spawner, config: &crate::Config) -> Self {
NetSim {
network: Mutex::new(Network::new(rand.clone(), config.net.clone())),
dns: Mutex::new(DnsServer::default()),
ipvs: IpVirtualServer::default(),
rand: rand.clone(),
time: time.clone(),
hooks_req: Default::default(),
hooks_rsp: Default::default(),
}
}
fn create_node(&self, id: NodeId) {
let mut network = self.network.lock();
network.insert_node(id);
}
fn reset_node(&self, id: NodeId) {
self.reset_node(id);
}
}
impl NetSim {
    /// Returns the network simulator of the currently running simulation.
    pub fn current() -> Arc<Self> {
        plugin::simulator()
    }

    /// Returns a copy of the current network statistics.
    pub fn stat(&self) -> Stat {
        self.network.lock().stat().clone()
    }

    /// Updates the network configuration through `f`.
    pub fn update_config(&self, f: impl FnOnce(&mut Config)) {
        let mut network = self.network.lock();
        network.update_config(f);
    }

    /// Resets the network-side state of node `id`.
    pub fn reset_node(&self, id: NodeId) {
        let mut network = self.network.lock();
        network.reset_node(id);
    }

    /// Assigns the IP address `ip` to `node`.
    pub fn set_ip(&self, node: NodeId, ip: IpAddr) {
        let mut network = self.network.lock();
        network.set_ip(node, ip);
    }

    /// Deprecated alias of `unclog_node`.
    #[deprecated(since = "0.3.0", note = "use `unclog_node` instead")]
    pub fn connect(&self, id: NodeId) {
        self.unclog_node(id);
    }

    /// Unclogs node `id` in both directions.
    pub fn unclog_node(&self, id: NodeId) {
        self.network.lock().unclog_node(id, Direction::Both);
    }

    /// Unclogs inbound traffic of node `id`.
    pub fn unclog_node_in(&self, id: NodeId) {
        self.network.lock().unclog_node(id, Direction::In);
    }

    /// Unclogs outbound traffic of node `id`.
    pub fn unclog_node_out(&self, id: NodeId) {
        self.network.lock().unclog_node(id, Direction::Out);
    }

    /// Deprecated alias of `clog_node`.
    #[deprecated(since = "0.3.0", note = "use `clog_node` instead")]
    pub fn disconnect(&self, id: NodeId) {
        self.clog_node(id);
    }

    /// Clogs node `id` in both directions.
    pub fn clog_node(&self, id: NodeId) {
        self.network.lock().clog_node(id, Direction::Both);
    }

    /// Clogs inbound traffic of node `id`.
    pub fn clog_node_in(&self, id: NodeId) {
        self.network.lock().clog_node(id, Direction::In);
    }

    /// Clogs outbound traffic of node `id`.
    pub fn clog_node_out(&self, id: NodeId) {
        self.network.lock().clog_node(id, Direction::Out);
    }

    /// Deprecated: unclogs both directed links `node1 -> node2` and `node2 -> node1`
    /// under a single lock acquisition.
    #[deprecated(since = "0.3.0", note = "call `unclog_link` twice instead")]
    pub fn connect2(&self, node1: NodeId, node2: NodeId) {
        let mut network = self.network.lock();
        network.unclog_link(node1, node2);
        network.unclog_link(node2, node1);
    }

    /// Unclogs the directed link `src -> dst`.
    pub fn unclog_link(&self, src: NodeId, dst: NodeId) {
        self.network.lock().unclog_link(src, dst);
    }

    /// Deprecated: clogs both directed links `node1 -> node2` and `node2 -> node1`
    /// under a single lock acquisition.
    #[deprecated(since = "0.3.0", note = "call `clog_link` twice instead")]
    pub fn disconnect2(&self, node1: NodeId, node2: NodeId) {
        let mut network = self.network.lock();
        network.clog_link(node1, node2);
        network.clog_link(node2, node1);
    }

    /// Clogs the directed link `src -> dst`.
    pub fn clog_link(&self, src: NodeId, dst: NodeId) {
        self.network.lock().clog_link(src, dst);
    }

    /// Adds a simulated DNS record mapping `hostname` to `ip`.
    pub fn add_dns_record(&self, hostname: &str, ip: IpAddr) {
        self.dns.lock().add(hostname, ip);
    }

    /// Resolves `hostname` through the simulated DNS table.
    pub(crate) fn lookup_host(&self, hostname: &str) -> Option<IpAddr> {
        self.dns.lock().lookup(hostname)
    }

    /// Returns the simulator-wide IP virtual server table.
    pub fn global_ipvs(&self) -> &IpVirtualServer {
        &self.ipvs
    }

    /// Installs a hook that inspects RPC requests sent by `node`.
    ///
    /// The hook is consulted when `node` sends a message, before routing. `f` is
    /// called only when the payload is a `(u64, Payload)` whose inner payload is a
    /// `(u64, R, Bytes)`. Returning `false` silently drops the message. Payloads of
    /// any other shape pass through. This replaces any request hook previously
    /// installed for `node`, whatever its `R`.
    #[cfg(feature = "rpc")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
    pub fn hook_rpc_req<R: 'static>(
        &self,
        node: NodeId,
        f: impl Fn(&R) -> bool + Send + Sync + 'static,
    ) {
        self.hooks_req.lock().insert(
            node,
            Arc::new(move |payload| {
                if let Some((_, payload)) = payload.downcast_ref::<(u64, Payload)>() {
                    if let Some((_, msg, _)) = payload.downcast_ref::<(u64, R, Bytes)>() {
                        return f(msg);
                    }
                }
                true
            }),
        );
    }

    /// Installs a hook that inspects RPC responses delivered to `node`.
    ///
    /// The hook is captured when a message addressed to `node` is sent. It runs when
    /// that message is delivered, after the simulated latency. `f` is called only
    /// when the payload is a `(u64, Payload)` whose inner payload is an
    /// `(R, Bytes)`. Returning `false` silently drops the message. Payloads of any
    /// other shape pass through. This replaces any response hook previously
    /// installed for `node`.
    #[cfg(feature = "rpc")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rpc")))]
    pub fn hook_rpc_rsp<R: 'static>(
        &self,
        node: NodeId,
        f: impl Fn(&R) -> bool + Send + Sync + 'static,
    ) {
        self.hooks_rsp.lock().insert(
            node,
            Arc::new(move |payload| {
                if let Some((_, payload)) = payload.downcast_ref::<(u64, Payload)>() {
                    if let Some((msg, _)) = payload.downcast_ref::<(R, Bytes)>() {
                        return f(msg);
                    }
                }
                true
            }),
        );
    }

    /// Sleeps for a random delay on the simulated clock.
    ///
    /// The delay is normally 0–4 µs. When `buggify_with_prob(0.1)` fires, it is
    /// 1–4 s instead. This never fails; the `io::Result` only lets callers chain it
    /// with `?`.
    async fn rand_delay(&self) -> io::Result<()> {
        let mut delay = Duration::from_micros(self.rand.with(|rng| rng.gen_range(0..5)));
        if buggify_with_prob(0.1) {
            delay = Duration::from_secs(self.rand.with(|rng| rng.gen_range(1..5)));
        }
        self.time.sleep(delay).await;
        Ok(())
    }

    /// Rewrites `dst` to the backend chosen by the IPVS table when `dst` is a
    /// registered virtual service for `protocol`. Otherwise returns `dst` unchanged.
    ///
    /// # Panics
    /// If the IPVS table yields a server address that does not parse as a
    /// `SocketAddr`.
    fn resolve_ipvs(&self, dst: SocketAddr, protocol: IpProtocol) -> SocketAddr {
        match self
            .ipvs
            .get_server(ServiceAddr::from_addr_proto(dst, protocol))
        {
            Some(addr) => addr.parse().expect("invalid socket address"),
            None => dst,
        }
    }

    /// Sends `msg` from `node` (source port `port`) to `dst` over `protocol`.
    ///
    /// The steps are, in order:
    /// 1. Wait a random delay.
    /// 2. Let the sender's request hook veto the message.
    /// 3. Apply the IPVS rewrite.
    /// 4. Ask the network for a route.
    ///
    /// If the network refuses (`try_send` returns `None`), the message is silently
    /// dropped. Otherwise it is delivered to the destination socket after the link
    /// latency, with source address `(ip, port)`, unless the destination's response
    /// hook rejects it.
    ///
    /// Always returns `Ok(())`: drops are not reported to the caller.
    ///
    /// # Panics
    /// If the IPVS table yields an unparsable server address.
    pub(crate) async fn send(
        &self,
        node: NodeId,
        port: u16,
        dst: SocketAddr,
        protocol: IpProtocol,
        msg: Payload,
    ) -> io::Result<()> {
        self.rand_delay().await?;
        // Clone the hook out in its own statement so the `hooks_req` guard is
        // released before the user callback runs. In edition <= 2021, a guard
        // created in an `if let` scrutinee would live through the whole body. A
        // hook that re-entered `hook_rpc_req` would then spin forever on the
        // non-reentrant spin lock.
        let req_hook = self.hooks_req.lock().get(&node).cloned();
        if let Some(hook) = req_hook {
            if !hook(&msg) {
                return Ok(());
            }
        }
        let dst = self.resolve_ipvs(dst, protocol);
        // Likewise, drop the `network` guard before locking `hooks_rsp` and
        // scheduling the delivery timer.
        let route = self.network.lock().try_send(node, dst, protocol);
        if let Some((ip, dst_node, socket, latency)) = route {
            trace!(?latency, "delay");
            // The response hook is captured now, so later hook changes do not
            // affect messages already in flight.
            let rsp_hook = self.hooks_rsp.lock().get(&dst_node).cloned();
            self.time.add_timer(latency, move || {
                if let Some(hook) = rsp_hook {
                    if !hook(&msg) {
                        return;
                    }
                }
                socket.deliver((ip, port).into(), dst, msg);
            });
        }
        Ok(())
    }

    /// Opens a simulated connection from `node` (source port `port`) to `dst`
    /// over `protocol`.
    ///
    /// After a random delay and the IPVS rewrite, the destination socket is handed
    /// the new connection via `new_connection`. Returns the sender toward `dst`,
    /// the receiver for traffic coming back, and the connection's source address.
    /// The link latency is only traced here. Per-message latency is applied by
    /// the channels built in `channel`.
    ///
    /// # Errors
    /// `ConnectionRefused` when the network refuses the route (`try_send`
    /// returns `None`).
    ///
    /// # Panics
    /// If the IPVS table yields an unparsable server address.
    pub(crate) async fn connect1(
        self: &Arc<Self>,
        node: NodeId,
        port: u16,
        dst: SocketAddr,
        protocol: IpProtocol,
    ) -> io::Result<(PayloadSender, PayloadReceiver, SocketAddr)> {
        self.rand_delay().await?;
        let dst = self.resolve_ipvs(dst, protocol);
        let (ip, dst_node, socket, latency) = self
            .network
            .lock()
            .try_send(node, dst, protocol)
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::ConnectionRefused, "connection refused")
            })?;
        let src: SocketAddr = (ip, port).into();
        // Forward direction: `node` -> `dst`.
        let (tx1, rx1) = self.channel(node, dst, protocol);
        // Reverse direction: `dst_node` -> `src`.
        let (tx2, rx2) = self.channel(dst_node, src, protocol);
        trace!(?latency, "delay");
        socket.new_connection(src, dst, tx2, rx1);
        Ok((tx1, rx2, src))
    }

    /// Creates an unbounded, one-directional channel for traffic from `node` to
    /// `dst` over `protocol`.
    ///
    /// Each `PayloadSender::send` probes the link. It stamps the message with its
    /// arrival instant (now + latency), or with `None` if the link is unavailable.
    ///
    /// The receiving stream yields messages in send order. For an unstamped message
    /// it re-probes the link with exponential backoff (starting at 1 ms, doubling,
    /// capped at 10 s) until the link is available. It then sleeps until the
    /// arrival instant. A blocked message therefore also holds back the messages
    /// behind it.
    fn channel(
        self: &Arc<Self>,
        node: NodeId,
        dst: SocketAddr,
        protocol: IpProtocol,
    ) -> (PayloadSender, PayloadReceiver) {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let net = self.clone();
        let test_link = Arc::new(move || {
            // Release the network lock before reading the clock.
            let latency = net
                .network
                .lock()
                .try_send(node, dst, protocol)
                .map(|(_, _, _, latency)| latency);
            latency.map(|latency| net.time.now_instant() + latency)
        });
        let sender = PayloadSender {
            test_link: test_link.clone(),
            tx,
        };
        let recver = async_stream::stream! {
            while let Some((value, mut state)) = rx.recv().await {
                let mut backoff = Duration::from_millis(1);
                let arrive_time = loop {
                    if let Some(arrive_time) = state {
                        break arrive_time;
                    }
                    sleep(backoff).await;
                    backoff = (backoff * 2).min(Duration::from_secs(10));
                    state = test_link();
                };
                sleep_until(arrive_time).await;
                yield value;
            }
        }
        .boxed();
        (sender, recver)
    }
}
/// Sending half of a simulated connection channel (built by `NetSim::channel`).
#[doc(hidden)]
pub struct PayloadSender {
    // Link probe: `Some(arrival instant)` if a message sent now can be routed,
    // `None` if the network currently refuses it.
    test_link: Arc<dyn Fn() -> State + Send + Sync>,
    // Unbounded queue of payloads paired with the link state sampled at send time.
    tx: mpsc::UnboundedSender<(Payload, State)>,
}
/// Link state stamped on a queued payload: its arrival instant, or `None` if
/// the link was unavailable when it was sent.
type State = Option<Instant>;
impl PayloadSender {
    /// Queues `value`, stamped with the link state sampled at this moment.
    /// Returns `None` if the receiving half has been dropped.
    fn send(&self, value: Payload) -> Option<()> {
        let probe = &self.test_link;
        match self.tx.send((value, probe())) {
            Ok(()) => Some(()),
            Err(_) => None,
        }
    }

    /// Whether the receiving half has been dropped.
    fn is_closed(&self) -> bool {
        self.tx.is_closed()
    }

    /// Completes once the receiving half has been dropped.
    async fn closed(&self) {
        self.tx.closed().await
    }
}
/// Receiving half of a simulated connection channel: a boxed stream that yields
/// each payload once its simulated arrival time is reached.
#[doc(hidden)]
pub type PayloadReceiver = BoxStream<'static, Payload>;
/// Handle for an address bound in the simulated network. Dropping it closes the
/// binding unless the owning node has been killed.
pub(crate) struct BindGuard {
    // Simulator the binding lives in.
    net: Arc<NetSim>,
    // Node that performed the bind.
    node: Arc<NodeInfo>,
    // Address actually bound, as returned by the network.
    addr: SocketAddr,
    // Protocol the address is bound for.
    protocol: IpProtocol,
}
impl BindGuard {
    /// Binds `socket` on the current task's node to the first address resolved
    /// from `addr` that the network accepts. A random delay is inserted before
    /// each attempt.
    ///
    /// # Errors
    /// Propagates resolution errors. If every candidate fails, returns the error
    /// from the last attempt. If `addr` resolved to no addresses at all, returns
    /// `InvalidInput`.
    pub async fn bind(
        addr: impl ToSocketAddrs,
        protocol: IpProtocol,
        socket: Arc<dyn Socket>,
    ) -> io::Result<Self> {
        let net = plugin::simulator::<NetSim>();
        let node = crate::context::current_task().node.clone();
        let mut last_err = None;
        for candidate in lookup_host(addr).await? {
            net.rand_delay().await?;
            // Take the result out of the lock first so `net` can be moved below.
            let outcome = net
                .network
                .lock()
                .bind(node.id, candidate, protocol, socket.clone());
            match outcome {
                Ok(bound) => {
                    return Ok(BindGuard {
                        net,
                        node,
                        addr: bound,
                        protocol,
                    });
                }
                Err(e) => last_err = Some(e),
            }
        }
        let err = last_err.unwrap_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any addresses",
            )
        });
        Err(err)
    }
}
impl Drop for BindGuard {
    /// Closes the bound address, unless the owning node has been killed
    /// (presumably its bindings are already torn down by the node reset — TODO
    /// confirm).
    ///
    /// `try_lock` is used on purpose: if the network lock is already held, for
    /// example when this guard is dropped from inside a network operation, the
    /// close is skipped. Calling `lock` there would spin forever on the
    /// non-reentrant spin lock.
    fn drop(&mut self) {
        let node_alive = !self.node.is_killed();
        if node_alive {
            if let Some(mut network) = self.net.network.try_lock() {
                network.close(self.node.id, self.addr, self.protocol);
            }
        }
    }
}