mod common;
use common::{get_model, model_dir};
use std::{
sync::mpsc::{sync_channel, Receiver, SyncSender},
thread::spawn,
};
use bridgestan::Model;
/// Entries of `model_dir()` that the `create_all_*` tests skip entirely.
/// NOTE(review): why each of these models is excluded is not visible in this file — confirm.
const EXCLUDED_MODELS: [&str; 4] = ["logistic", "regression", "syntax_error", "external"];
/// Instantiates every entry of `model_dir()` (except `EXCLUDED_MODELS`) one at a
/// time on the test thread.
///
/// Only `ode` and `throw_data` may fail to construct; every other model must
/// construct successfully and report a name containing its entry name.
#[test]
fn create_all_serial() {
    let base = model_dir();
    for entry in base.read_dir().unwrap() {
        let path = entry.unwrap().path();
        let name = path.file_name().unwrap().to_str().unwrap();
        if EXCLUDED_MODELS.contains(&name) {
            continue;
        }
        let (lib, data) = get_model(name);
        let Ok(model) = Model::new(&lib, data.as_ref(), 42) else {
            assert!(
                name == "ode" || name == "throw_data",
                "model `{name}` unexpectedly failed to construct"
            );
            continue;
        };
        let model_name = model.name().unwrap();
        assert!(
            model_name.contains(name),
            "model `{name}` reported unexpected name `{model_name}`"
        );
    }
}
/// Instantiates (and immediately drops) every non-excluded model, but keeps each
/// loaded library alive until all entries have been processed. The libraries are
/// then dropped front-to-back, i.e. in the order they were loaded.
///
/// Only `ode` and `throw_data` may fail to construct.
#[test]
fn create_all_late_drop_fwd() {
    let base = model_dir();
    let names: Vec<String> = base
        .read_dir()
        .unwrap()
        .map(|entry| {
            let path = entry.unwrap().path();
            path.file_name().unwrap().to_str().unwrap().to_string()
        })
        .collect();
    let libs: Vec<_> = names
        .into_iter()
        .filter(|name| !EXCLUDED_MODELS.contains(&name.as_str()))
        .map(|name| {
            let (lib, data) = get_model(&name);
            let Ok(model) = Model::new(&lib, data.as_ref(), 42) else {
                assert!(
                    name == "ode" || name == "throw_data",
                    "model `{name}` unexpectedly failed to construct"
                );
                return lib;
            };
            let model_name = model.name().unwrap();
            assert!(
                model_name.contains(&name),
                "model `{name}` reported unexpected name `{model_name}`"
            );
            // `model` was constructed from `&lib`; release it before `lib` is
            // moved out of the closure.
            drop(model);
            lib
        })
        .collect();
    // `Vec`'s by-value iterator yields elements front-to-back, so the libraries
    // are dropped in load order.
    libs.into_iter().for_each(drop)
}
/// Instantiates every non-excluded model, each on its own freshly spawned thread.
///
/// Each thread is joined before the next one is spawned, so at most one runs at a
/// time. Only `ode` and `throw_data` may fail to construct.
#[test]
fn create_all_thread_serial() {
    let base = model_dir();
    let names: Vec<String> = base
        .read_dir()
        .unwrap()
        .map(|entry| {
            let path = entry.unwrap().path();
            path.file_name().unwrap().to_str().unwrap().to_string()
        })
        .collect();
    names
        .into_iter()
        // Filter on the calling thread so excluded entries never get a thread.
        .filter(|name| !EXCLUDED_MODELS.contains(&name.as_str()))
        .for_each(|name| {
            spawn(move || {
                let (lib, data) = get_model(&name);
                let Ok(model) = Model::new(&lib, data.as_ref(), 42) else {
                    assert!(
                        name == "ode" || name == "throw_data",
                        "model `{name}` unexpectedly failed to construct"
                    );
                    return;
                };
                let model_name = model.name().unwrap();
                assert!(
                    model_name.contains(&name),
                    "model `{name}` reported unexpected name `{model_name}`"
                );
            })
            .join()
            .unwrap()
        })
}
/// Instantiates every non-excluded model, each on its own thread, with all
/// threads spawned before any is joined so they may run concurrently.
///
/// Only `ode` and `throw_data` may fail to construct.
#[test]
fn create_all_parallel() {
    let base = model_dir();
    let names: Vec<String> = base
        .read_dir()
        .unwrap()
        .map(|entry| {
            let path = entry.unwrap().path();
            path.file_name().unwrap().to_str().unwrap().to_string()
        })
        .collect();
    let handles: Vec<_> = names
        .into_iter()
        // Filter on the calling thread so excluded entries never get a thread.
        .filter(|name| !EXCLUDED_MODELS.contains(&name.as_str()))
        .map(|name| {
            spawn(move || {
                let (lib, data) = get_model(&name);
                let Ok(model) = Model::new(&lib, data.as_ref(), 42) else {
                    assert!(
                        name == "ode" || name == "throw_data",
                        "model `{name}` unexpectedly failed to construct"
                    );
                    return;
                };
                let model_name = model.name().unwrap();
                assert!(
                    model_name.contains(&name),
                    "model `{name}` reported unexpected name `{model_name}`"
                );
            })
        })
        .collect();
    handles
        .into_iter()
        .for_each(|handle| handle.join().unwrap())
}
/// Loads and releases `throw_data`, then loads and releases the unrelated
/// `stdnormal` library.
#[test]
fn load_after_unload_diff() {
    for model_name in ["throw_data", "stdnormal"] {
        let (library, _) = get_model(model_name);
        drop(library);
    }
}
/// Loads `throw_data` twice in a row, dropping the library in between. Each time,
/// constructing the model must return an error.
#[test]
fn load_after_unload_same() {
    for _attempt in 0..2 {
        let (library, data) = get_model("throw_data");
        let Err(_) = Model::new(&library, data, 42) else {
            panic!("Did not return error")
        };
        drop(library);
    }
}
/// Keeps two different libraries loaded at the same time, then drops them in the
/// order they were loaded.
#[test]
fn load_twice_diff() {
    let (throw_data_lib, _) = get_model("throw_data");
    let (stdnormal_lib, _) = get_model("stdnormal");
    // Arrays iterate front-to-back by value: `throw_data` goes first.
    for library in [throw_data_lib, stdnormal_lib] {
        drop(library);
    }
}
/// Keeps two different libraries loaded at the same time, then drops them in the
/// reverse of the order they were loaded.
#[test]
fn load_twice_reorder_diff() {
    let (throw_data_lib, _) = get_model("throw_data");
    let (stdnormal_lib, _) = get_model("stdnormal");
    // Arrays iterate front-to-back by value: `stdnormal` (loaded last) goes first.
    for library in [stdnormal_lib, throw_data_lib] {
        drop(library);
    }
}
/// Loads `throw_data` twice so that both handles are live at once. Constructing a
/// model from each handle must return an error. Afterwards, the handles are
/// dropped in load order.
#[test]
fn load_twice_same() {
    let (first_lib, first_data) = get_model("throw_data");
    let (second_lib, second_data) = get_model("throw_data");
    for (library, data) in [(&first_lib, first_data), (&second_lib, second_data)] {
        let Err(_) = Model::new(library, data, 42) else {
            panic!("Did not return error")
        };
    }
    drop(first_lib);
    drop(second_lib);
}
/// Loads four libraries, drops the first one, and loads a fifth. The remaining
/// libraries are then dropped in load order.
#[test]
fn load_order_all_serial() {
    let (bernoulli, _) = get_model("bernoulli");
    let (fr_gaussian, _) = get_model("fr_gaussian");
    let (full, _) = get_model("full");
    let (gaussian, _) = get_model("gaussian");
    drop(bernoulli);
    let (jacobian, _) = get_model("jacobian");
    for library in [fr_gaussian, full, gaussian, jacobian] {
        drop(library);
    }
}
/// Drives three worker threads through an interleaved load/unload sequence.
///
/// Each library is owned by a different thread. The sequence is:
/// load `bernoulli`, load `gaussian`, unload `bernoulli`, load `jacobian`,
/// unload `gaussian`, unload `jacobian`.
#[test]
fn load_order_min_parallel() {
    /// Main-thread side of one worker.
    ///
    /// Holds one rendezvous sender per command, plus the receiver on which the
    /// worker acknowledges completed `load`/`unload` commands.
    struct Worker {
        load_tx: SyncSender<()>,
        unload_tx: SyncSender<()>,
        exit_tx: SyncSender<()>,
        ack_rx: Receiver<()>,
    }

    impl Worker {
        /// Tells the worker to load its library and waits until it has.
        fn load(&self) {
            self.load_tx.send(()).unwrap();
            self.ack_rx.recv().unwrap();
        }

        /// Tells the worker to unload its library and waits until it has.
        fn unload(&self) {
            self.unload_tx.send(()).unwrap();
            self.ack_rx.recv().unwrap();
        }

        /// Lets the worker thread return.
        fn exit(&self) {
            self.exit_tx.send(()).unwrap();
        }
    }

    let names = ["bernoulli", "gaussian", "jacobian"];
    let (workers, handles): (Vec<Worker>, Vec<_>) = names
        .into_iter()
        .map(|name| {
            // Bound-0 channels are rendezvous channels: each `send` blocks until
            // the other side `recv`s, keeping main thread and worker in lockstep.
            let (load_tx, load_rx) = sync_channel::<()>(0);
            let (unload_tx, unload_rx) = sync_channel::<()>(0);
            let (exit_tx, exit_rx) = sync_channel::<()>(0);
            let (ack_tx, ack_rx) = sync_channel::<()>(0);
            let handle = spawn(move || {
                load_rx.recv().unwrap();
                let (lib, _) = get_model(name);
                ack_tx.send(()).unwrap();
                unload_rx.recv().unwrap();
                // SAFETY: no `Model` is ever constructed from `lib` in this thread,
                // so nothing here refers to the library when it is unloaded.
                // NOTE(review): confirm against the `# Safety` section of
                // `StanLibrary::unload_library`.
                unsafe { lib.unload_library() };
                ack_tx.send(()).unwrap();
                exit_rx.recv().unwrap();
            });
            (
                Worker {
                    load_tx,
                    unload_tx,
                    exit_tx,
                    ack_rx,
                },
                handle,
            )
        })
        .unzip();

    workers[0].load();
    workers[1].load();
    workers[0].unload();
    workers[0].exit();
    workers[2].load();
    workers[1].unload();
    workers[1].exit();
    workers[2].unload();
    workers[2].exit();
    handles.into_iter().for_each(|h| h.join().unwrap());
}
/// Single-threaded version of the minimal interleaving:
/// load `bernoulli`, load `gaussian`, drop `bernoulli`, load `jacobian`,
/// then drop `gaussian` and `jacobian` in that order.
#[test]
fn load_order_min_serial() {
    let (bernoulli, _) = get_model("bernoulli");
    let (gaussian, _) = get_model("gaussian");
    drop(bernoulli);
    let (jacobian, _) = get_model("jacobian");
    for library in [gaussian, jacobian] {
        drop(library);
    }
}
/// Spawns 50 threads before joining any of them. Each thread loads `throw_data`
/// and checks that constructing its model returns an error. Every thread must
/// finish without panicking.
#[test]
fn load_parallel() {
    let workers: Vec<_> = (0..50)
        .map(|_| {
            spawn(|| {
                let (library, data) = get_model("throw_data");
                let Err(_) = Model::new(&library, data, 42) else {
                    panic!("Did not return error")
                };
            })
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }
}