use super::*;
use crate::task::{JoinHandle, NodeId, ToNodeId};
use spin::Mutex;
use std::{
any::{Any, TypeId},
collections::HashMap,
fmt,
future::Future,
net::IpAddr,
sync::Arc,
time::Duration,
};
mod builder;
pub(crate) mod context;
mod metrics;
pub use self::builder::Builder;
pub use self::metrics::RuntimeMetrics;
/// Simulation runtime.
///
/// Owns the seeded random number generator and the task executor, plus a
/// [`Handle`] that shares them with code running inside the runtime
/// (see `Runtime::with_seed_and_config`).
pub struct Runtime {
    /// Seeded RNG; a clone of it is also stored in `handle`.
    rand: rand::GlobalRng,
    /// Executor driving all tasks; `block_on` and `set_time_limit` delegate to it.
    task: task::Executor,
    /// Cloneable handle returned by `Runtime::handle`.
    handle: Handle,
}
impl Default for Runtime {
fn default() -> Self {
Self::new()
}
}
impl Runtime {
    /// Creates a runtime with seed `0` and the default [`Config`].
    pub fn new() -> Self {
        Runtime::with_seed_and_config(0, Config::default())
    }

    /// Creates a runtime whose RNG is seeded with `seed` and configured by `config`.
    ///
    /// The file-system (`fs::FsSim`) and network (`net::NetSim`) simulators are
    /// registered automatically.
    pub fn with_seed_and_config(seed: u64, config: Config) -> Self {
        let rng = rand::GlobalRng::new_with_seed(seed);
        // The simulator registry is shared between the executor and the handle.
        let registry = Arc::new(Mutex::new(HashMap::new()));
        let executor = task::Executor::new(rng.clone(), Arc::clone(&registry));

        let time = executor.time_handle().clone();
        let task_handle = executor.handle().clone();
        let handle = Handle {
            rand: rng.clone(),
            time,
            task: task_handle,
            sims: registry,
            config,
            allow_system_thread: false,
        };

        let runtime = Runtime {
            rand: rng,
            task: executor,
            handle,
        };
        runtime.add_simulator::<fs::FsSim>();
        runtime.add_simulator::<net::NetSim>();
        runtime
    }

    /// Constructs simulator `S` and registers it under `TypeId::of::<S>()`,
    /// replacing any simulator previously registered for `S`.
    ///
    /// `S` is built from the runtime's RNG, time handle, the spawner of node
    /// `NodeId::zero()` and the config, and node zero is then created in it.
    /// The registry lock is held while `S::new1` and `create_node` run.
    ///
    /// # Panics
    ///
    /// Panics if the task handle has no node `NodeId::zero()`.
    pub fn add_simulator<S: plugin::Simulator>(&self) {
        let h = &self.handle;
        let mut registry = h.sims.lock();
        let root_node = h.task.get_node(NodeId::zero()).unwrap();
        let simulator = Arc::new(S::new1(&h.rand, &h.time, &root_node, &h.config));
        simulator.create_node(NodeId::zero());
        registry.insert(TypeId::of::<S>(), simulator);
    }

    /// Returns the handle shared with code running inside this runtime.
    pub fn handle(&self) -> &Handle {
        &self.handle
    }

    /// Starts building a new node; see [`NodeBuilder`].
    pub fn create_node(&self) -> NodeBuilder<'_> {
        self.handle.create_node()
    }

    /// Runs `future` to completion on this runtime's executor and returns its output.
    pub fn block_on<F: Future>(&self, future: F) -> F::Output {
        // The guard from `crate::context::enter` is held until `block_on` returns.
        let _entered = crate::context::enter(self.handle.clone());
        self.task.block_on(future)
    }

    /// Forwards `limit` to the executor's `set_time_limit`.
    ///
    /// NOTE(review): presumably a cap on simulated time — confirm in `task::Executor`.
    pub fn set_time_limit(&mut self, limit: Duration) {
        self.task.set_time_limit(limit);
    }

    /// Sets the `allow_system_thread` flag on this runtime's own handle.
    ///
    /// Clones of the handle taken before this call keep their previous value.
    pub fn set_allow_system_thread(&mut self, allowed: bool) {
        self.handle.allow_system_thread = allowed;
    }

    /// Runs `f()` twice on fresh runtimes built from the same `seed` and `config`,
    /// each on its own OS thread, and returns the second run's output.
    ///
    /// The first run enables the RNG log and takes it afterwards; the second run
    /// passes that log to `enable_check`. A panic in either run is re-raised after
    /// printing the seed needed to reproduce it.
    pub fn check_determinism<F>(seed: u64, config: Config, f: fn() -> F) -> F::Output
    where
        F: Future + 'static,
        F::Output: Send,
    {
        let recording_config = config.clone();
        let recorded = std::thread::spawn(move || {
            let rt = Runtime::with_seed_and_config(seed, recording_config);
            rt.rand.enable_log();
            rt.block_on(f());
            rt.rand.take_log().unwrap()
        })
        .join()
        .unwrap_or_else(|payload| panic_with_info(seed, payload));

        std::thread::spawn(move || {
            let rt = Runtime::with_seed_and_config(seed, config);
            rt.rand.enable_check(recorded);
            rt.block_on(f())
        })
        .join()
        .unwrap_or_else(|payload| panic_with_info(seed, payload))
    }
}
/// Prints which `MADSIM_TEST_SEED` reproduces a failure, then re-raises the
/// original panic `payload` unchanged.
fn panic_with_info(seed: u64, payload: Box<dyn Any + Send>) -> ! {
    let note = format!(
        "note: run with `MADSIM_TEST_SEED={}` environment variable to reproduce this error",
        seed
    );
    eprintln!("{note}");
    // `resume_unwind` re-raises the payload without invoking the panic hook.
    std::panic::resume_unwind(payload)
}
/// Cloneable handle to a [`Runtime`], giving access to its RNG, time handle,
/// tasks, simulators and configuration.
///
/// Obtained from [`Runtime::handle`] or, inside a runtime context, from
/// [`Handle::current`] / [`Handle::try_current`].
#[derive(Clone)]
pub struct Handle {
    /// RNG shared with the owning runtime.
    pub(crate) rand: rand::GlobalRng,
    /// Time handle taken from the runtime's executor.
    pub(crate) time: time::TimeHandle,
    /// Task handle taken from the runtime's executor; node operations delegate to it.
    pub(crate) task: task::TaskHandle,
    /// Registry of installed simulators; the `Arc` is shared (not copied) between clones.
    pub(crate) sims: Arc<Simulators>,
    /// Configuration the runtime was created with.
    pub(crate) config: Config,
    /// Set by `Runtime::set_allow_system_thread`; `false` on a new runtime.
    /// NOTE(review): read outside this file — presumably gates spawning real OS threads; confirm.
    pub(crate) allow_system_thread: bool,
}
impl fmt::Debug for Handle {
    /// Writes just the type name `Handle`; no field is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Handle")
    }
}
/// Registry of simulator plugins, keyed by the `TypeId` of each concrete
/// simulator type and guarded by a spin lock (`spin::Mutex`).
pub(crate) type Simulators = Mutex<HashMap<TypeId, Arc<dyn plugin::Simulator>>>;
/// Error returned by [`Handle::try_current`] when there is no runtime handle in
/// the current context.
#[derive(Debug)]
pub struct TryCurrentError;

impl fmt::Display for TryCurrentError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("no runtime handle found: must be called from within a madsim runtime context")
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for TryCurrentError {}
impl Handle {
    /// Returns a clone of the handle of the current runtime context.
    ///
    /// NOTE(review): behaviour outside a runtime is decided by `context::current`
    /// (presumably a panic); use [`Handle::try_current`] for a fallible variant.
    pub fn current() -> Self {
        context::current(|active| active.clone())
    }

    /// Like [`Handle::current`], but returns [`TryCurrentError`] when there is
    /// no current handle.
    pub fn try_current() -> Result<Self, TryCurrentError> {
        match context::try_current(|active| active.clone()) {
            Some(handle) => Ok(handle),
            None => Err(TryCurrentError),
        }
    }

    /// Returns the seed of the runtime's RNG.
    pub fn seed(&self) -> u64 {
        self.rand.seed()
    }

    /// Kills node `id` (delegates to the task handle).
    pub fn kill(&self, id: impl ToNodeId) {
        self.task.kill(&id);
    }

    /// Restarts node `id` (delegates to the task handle).
    pub fn restart(&self, id: impl ToNodeId) {
        self.task.restart(&id);
    }

    /// Pauses node `id` (delegates to the task handle).
    pub fn pause(&self, id: impl ToNodeId) {
        self.task.pause(id);
    }

    /// Resumes node `id` (delegates to the task handle).
    pub fn resume(&self, id: impl ToNodeId) {
        self.task.resume(id);
    }

    /// Delivers a Ctrl-C signal to node `id` (delegates to the task handle).
    pub fn send_ctrl_c(&self, id: impl ToNodeId) {
        self.task.send_ctrl_c(id);
    }

    /// Returns whether node `id` has exited, as reported by the task handle.
    pub fn is_exit(&self, id: impl ToNodeId) -> bool {
        self.task.is_exit(id)
    }

    /// Starts building a new node; see [`NodeBuilder`].
    pub fn create_node(&self) -> NodeBuilder<'_> {
        NodeBuilder::new(self)
    }

    /// Looks up node `id`, returning `None` if the task handle does not know it.
    pub fn get_node(&self, id: impl ToNodeId) -> Option<NodeHandle> {
        let task = self.task.get_node(id)?;
        Some(NodeHandle { task })
    }

    /// Returns a metrics view backed by a clone of the task handle.
    pub fn metrics(&self) -> RuntimeMetrics {
        let task = self.task.clone();
        RuntimeMetrics { task }
    }
}
/// Builder for a simulated node, created by `Runtime::create_node` or
/// `Handle::create_node` and finished with `NodeBuilder::build`.
///
/// The `pub(crate)` fields are handed to the task handle's `create_node` in `build`.
pub struct NodeBuilder<'a> {
    /// Handle of the runtime the node will be created in.
    handle: &'a Handle,
    /// Optional node name, set by `name()`.
    pub(crate) name: Option<String>,
    /// Optional IP address, set by `ip()`; assigned through `NetSim` in `build`.
    pub(crate) ip: Option<IpAddr>,
    /// Optional core count, set by `cores()`; never `Some(0)`.
    pub(crate) cores: Option<usize>,
    /// Optional init function, set by `init()`.
    pub(crate) init: Option<task::InitFn>,
    /// Set to `true` by `restart_on_panic()`.
    pub(crate) restart_on_panic: bool,
    /// Messages appended by `restart_on_panic_matching()`.
    /// NOTE(review): presumably matched against panic messages to decide on restart; confirm in `task`.
    pub(crate) restart_on_panic_matching: Vec<String>,
}
impl<'a> NodeBuilder<'a> {
    /// Starts a builder for `handle` with every option unset.
    fn new(handle: &'a Handle) -> Self {
        Self {
            handle,
            name: None,
            ip: None,
            cores: None,
            init: None,
            restart_on_panic: false,
            restart_on_panic_matching: Vec::new(),
        }
    }

    /// Sets the node's name.
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets the node's init function.
    ///
    /// Each time the stored init function is invoked, it calls `new_task` to
    /// build a fresh future and spawns it locally. When that future finishes,
    /// `exit()` is called on the handle the init function received.
    pub fn init<F>(self, new_task: impl Fn() -> F + Send + Sync + 'static) -> Self
    where
        F: Future + 'static,
    {
        Self {
            init: Some(Arc::new(move |handle| {
                let fut = new_task();
                let on_done = handle.clone();
                handle.spawn_local(async move {
                    fut.await;
                    on_done.exit();
                });
            })),
            ..self
        }
    }

    /// Sets the `restart_on_panic` flag.
    pub fn restart_on_panic(self) -> Self {
        Self {
            restart_on_panic: true,
            ..self
        }
    }

    /// Appends `msg` to the `restart_on_panic_matching` list.
    pub fn restart_on_panic_matching(mut self, msg: impl Into<String>) -> Self {
        let pattern: String = msg.into();
        self.restart_on_panic_matching.push(pattern);
        self
    }

    /// Sets the node's IP address.
    pub fn ip(self, ip: IpAddr) -> Self {
        Self { ip: Some(ip), ..self }
    }

    /// Sets the node's core count.
    ///
    /// # Panics
    ///
    /// Panics if `cores` is 0.
    pub fn cores(self, cores: usize) -> Self {
        assert_ne!(cores, 0, "cores must be greater than 0");
        Self {
            cores: Some(cores),
            ..self
        }
    }

    /// Creates the node in the task executor, then in every registered simulator.
    ///
    /// If an IP was set, it is assigned through the `NetSim` simulator. The
    /// simulator registry lock is held while this happens.
    pub fn build(self) -> NodeHandle {
        let task = self.handle.task.create_node(&self);
        let registry = self.handle.sims.lock();
        for sim in registry.values() {
            sim.create_node(task.node_id());
            // Only the network simulator receives the explicit IP address.
            if let (Some(ip), Some(net)) = (self.ip, sim.downcast_ref::<net::NetSim>()) {
                net.set_ip(task.node_id(), ip);
            }
        }
        NodeHandle { task }
    }
}
/// Handle to a node, returned by `NodeBuilder::build` and `Handle::get_node`.
#[derive(Clone)]
pub struct NodeHandle {
    /// Spawner for the node; `id` and `spawn` delegate to it.
    task: task::Spawner,
}
impl NodeHandle {
    /// Returns this node's id.
    pub fn id(&self) -> NodeId {
        self.task.node_id()
    }

    /// Spawns `future` as a task on this node and returns its join handle.
    #[track_caller]
    pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let spawner = &self.task;
        spawner.spawn(future)
    }
}
/// Installs a `tracing_subscriber::fmt` subscriber as the global default, at
/// most once per process.
///
/// Repeated calls are no-ops. If some other global subscriber is already
/// installed, it is kept and this function does not panic.
pub fn init_logger() {
    use std::sync::Once;
    static LOGGER_INIT: Once = Once::new();
    LOGGER_INIT.call_once(|| {
        // `fmt::init` panics when a global subscriber already exists, which would
        // also poison this `Once` and make every later call panic. `try_init`
        // reports that case as an error instead. Ignoring it is fine because a
        // logger is then already in place.
        let _ = tracing_subscriber::fmt::try_init();
    });
}