use crate::ffi;
use crate::VERSION;
use std::borrow::Borrow;
use std::collections::hash_map::DefaultHasher;
use std::ffi::c_char;
use std::ffi::c_int;
use std::ffi::c_uint;
use std::ffi::CStr;
use std::ffi::OsStr;
use std::hash::Hash;
use std::hash::Hasher;
use std::mem::forget;
use std::mem::ManuallyDrop;
use std::ptr::NonNull;
use std::ptr::{null, null_mut};
use std::str::Utf8Error;
use std::time::Instant;
use thiserror::Error;
/// A loaded BridgeStan model library, obtained from [`open_library`].
///
/// Dropping a `StanLibrary` leaks the underlying shared-library handle rather
/// than unloading it; [`StanLibrary::unload_library`] unloads it explicitly.
pub struct StanLibrary {
// FFI function table. Wrapped in `ManuallyDrop` so that `Drop` and
// `unload_library` can move it out exactly once and choose whether the
// library handle is leaked or released.
lib: ManuallyDrop<ffi::BridgeStan>,
// Identifier hashed from the load time and path in `open_library`; compared
// to check that an `Rng` and a `Model` come from the same library.
id: u64,
}
// SAFETY (review note): these impls allow one loaded library to be shared and
// used across threads. `Model::new` rejects models whose info string lacks
// `STAN_THREADS=true`; soundness presumably also relies on the library's entry
// points being thread-safe — TODO confirm against the BridgeStan C API.
unsafe impl Send for StanLibrary {}
unsafe impl Sync for StanLibrary {}
impl Drop for StanLibrary {
    /// Leak the shared-library handle instead of unloading it.
    ///
    /// NOTE(review): leaking looks deliberate, since unloading is only offered
    /// through the `unsafe` `unload_library`. Presumably unloading on every drop
    /// is hazardous — confirm.
    fn drop(&mut self) {
        // SAFETY: `self.lib` is taken exactly once and never touched again.
        let bridge = unsafe { ManuallyDrop::take(&mut self.lib) };
        let handle = bridge.into_library();
        forget(handle);
    }
}
/// C-ABI function type accepted by [`StanLibrary::set_print_callback`].
///
/// NOTE(review): the arguments are presumably a pointer to the printed bytes and
/// their length — confirm against the BridgeStan C API.
pub type StanPrintCallback = extern "C" fn(*const c_char, usize);
impl StanLibrary {
    /// Register `callback` with the loaded library via `bs_set_print_callback`.
    ///
    /// # Safety
    ///
    /// `callback` is invoked from C code inside the loaded library. The caller
    /// must ensure it is sound to call with whatever arguments the library
    /// passes. NOTE(review): those are presumably a pointer to message bytes and
    /// their length, not necessarily NUL-terminated — confirm against the
    /// BridgeStan C API. Because [`Model`] is `Sync`, the callback should also
    /// be safe to invoke from any thread.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::SetCallbackFailed`] with the library's error
    /// message if `bs_set_print_callback` reports failure (non-zero return).
    pub unsafe fn set_print_callback(&mut self, callback: StanPrintCallback) -> Result<()> {
        let mut err = ErrorMsg::new(self);
        // SAFETY: `err.as_ptr()` is a valid error out-pointer for the duration
        // of the call; the contract on `callback` is upheld by the caller.
        let rc = unsafe { self.lib.bs_set_print_callback(Some(callback), err.as_ptr()) };
        if rc == 0 {
            Ok(())
        } else {
            Err(BridgeStanError::SetCallbackFailed(err.message()))
        }
    }

    /// Unload the shared library instead of leaking it, as `Drop` does.
    ///
    /// # Safety
    ///
    /// After this call no code from the library may run again. This includes a
    /// registered print callback being invoked by the library, or any other
    /// lingering use of its functions. NOTE(review): the default `Drop` leaks
    /// rather than unloads, which suggests unloading is known to be hazardous —
    /// confirm the exact conditions under which this is sound.
    pub unsafe fn unload_library(mut self) {
        // SAFETY: `self.lib` is taken exactly once here, and `forget(self)`
        // below prevents `Drop` from taking it a second time.
        let lib = unsafe { ManuallyDrop::take(&mut self.lib) };
        drop(lib.into_library());
        forget(self);
    }
}
/// Failure to open the shared library or to resolve a symbol in it.
///
/// Wraps the underlying [`libloading::Error`].
#[derive(Error, Debug)]
#[error("Could not load target library: {0}")]
pub struct LoadingError(#[from] libloading::Error);
/// Errors returned by the BridgeStan bindings.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum BridgeStanError {
/// The library could not be loaded or a required symbol was missing.
#[error(transparent)]
InvalidLibrary(#[from] LoadingError),
/// Incompatible library version: (library version, version expected by this crate).
#[error("Bad Stan library version: Got {0} but expected {1}")]
BadLibraryVersion(String, String),
/// The model's info string does not contain `STAN_THREADS=true`; carries that info string.
#[error("The Stan library was compiled without threading support. Config was {0}")]
StanThreads(String),
/// A string returned by the library was not valid UTF-8.
#[error("Failed to decode string to UTF8")]
InvalidString(#[from] Utf8Error),
/// The library failed to construct a model or RNG; carries its error message.
#[error("Failed to construct model: {0}")]
ConstructFailed(String),
/// A model evaluation reported failure; carries the library's error message.
#[error("Failed during evaluation: {0}")]
EvaluationFailed(String),
/// Registering a print callback failed; carries the library's error message.
#[error("Failed to set a print-callback: {0}")]
SetCallbackFailed(String),
/// Compiling a Stan model failed (presumably raised by compile helpers elsewhere in the crate).
#[error("Failed to compile Stan model: {0}")]
ModelCompilingFailed(String),
/// Downloading BridgeStan failed (presumably raised by download helpers elsewhere in the crate).
#[error("Failed to download BridgeStan {VERSION} from github.com: {0}")]
DownloadFailed(String),
}
/// Crate-internal shorthand for results that fail with [`BridgeStanError`].
pub(crate) type Result<T> = std::result::Result<T, BridgeStanError>;
/// Load a BridgeStan model library from `path` and check that its version is
/// compatible with this crate.
///
/// The library's exported `bs_major_version`, `bs_minor_version` and
/// `bs_patch_version` integers are compared with this crate's Cargo version.
/// The major versions must be equal, and the library's minor version must be
/// at least this crate's minor version. The patch version is only reported.
///
/// # Errors
///
/// - [`BridgeStanError::InvalidLibrary`] if the library cannot be loaded or
///   lacks one of the required symbols.
/// - [`BridgeStanError::BadLibraryVersion`] if the versions are incompatible.
///   The library's version comes first, then the version this crate expects.
pub fn open_library<P: AsRef<OsStr>>(path: P) -> Result<StanLibrary> {
    // SAFETY: loading a shared library runs its initialisation code; the
    // caller is trusted to pass the path of a genuine BridgeStan model library.
    let library = unsafe { libloading::Library::new(&path) }.map_err(LoadingError)?;

    let major = read_exported_c_int(&library, b"bs_major_version")?;
    let minor = read_exported_c_int(&library, b"bs_minor_version")?;
    let patch = read_exported_c_int(&library, b"bs_patch_version")?;

    let self_major: c_int = env!("CARGO_PKG_VERSION_MAJOR")
        .parse()
        .expect("Cargo major version is an integer");
    let self_minor: c_int = env!("CARGO_PKG_VERSION_MINOR")
        .parse()
        .expect("Cargo minor version is an integer");
    let self_patch: c_int = env!("CARGO_PKG_VERSION_PATCH")
        .parse()
        .expect("Cargo patch version is an integer");

    // Require the same major version, and a library minor version at least as
    // new as ours (presumably so every entry point the bindings use exists).
    if self_major != major || self_minor > minor {
        return Err(BridgeStanError::BadLibraryVersion(
            format!("{major}.{minor}.{patch}"),
            format!("{self_major}.{self_minor}.{self_patch}"),
        ));
    }

    // SAFETY: the version check above confirms this is a compatible BridgeStan
    // library, whose exported functions the FFI table expects.
    let lib = unsafe { ffi::BridgeStan::from_library(library) }.map_err(LoadingError)?;
    let lib = ManuallyDrop::new(lib);

    // Identify this load by time and path so `Rng`s and `Model`s from different
    // libraries can be told apart.
    let mut hasher = DefaultHasher::new();
    Instant::now().hash(&mut hasher);
    path.as_ref().hash(&mut hasher);
    let id = hasher.finish();

    Ok(StanLibrary { lib, id })
}

/// Read the `int` global that `library` exports under the name `symbol`.
fn read_exported_c_int(library: &libloading::Library, symbol: &[u8]) -> Result<c_int> {
    // SAFETY: `symbol` is expected to name a `c_int` global (as the BridgeStan
    // version numbers do). Its address stays valid while `library` is loaded,
    // which holds for the whole of this call.
    let ptr: libloading::Symbol<*const c_int> =
        unsafe { library.get(symbol) }.map_err(LoadingError)?;
    Ok(unsafe { **ptr })
}
/// A model instance created from a loaded [`StanLibrary`].
///
/// `T` is any handle that can borrow the library, e.g. `StanLibrary`,
/// `&StanLibrary` or `Arc<StanLibrary>`. It keeps the library reachable for as
/// long as the model exists.
pub struct Model<T: Borrow<StanLibrary>> {
// Non-null pointer to the native model object; destroyed in `Drop`.
model: NonNull<ffi::bs_model>,
// Handle to the library whose functions operate on `model`.
lib: T,
}
// SAFETY (review note): `Model` methods take `&self` and call into the library,
// so `Sync` permits concurrent evaluation. `Model::new` rejects models whose info
// lacks `STAN_THREADS=true`; the library's thread-safety beyond that is assumed —
// TODO confirm.
unsafe impl<T: Sync + Borrow<StanLibrary>> Sync for Model<T> {}
unsafe impl<T: Send + Borrow<StanLibrary>> Send for Model<T> {}
/// A random number generator created by a loaded [`StanLibrary`].
///
/// It is required by `Model::param_constrain` when generated quantities are requested.
pub struct Rng<T: Borrow<StanLibrary>> {
// Non-null pointer to the native RNG object; destroyed in `Drop`.
rng: NonNull<ffi::bs_rng>,
// Handle to the library that created `rng`.
lib: T,
}
// SAFETY (review note): in this file the raw RNG pointer is only handed to the
// library through `&mut Rng` (in `param_constrain`) or in `Drop`, so shared
// references never use it. Thread-safety of the native RNG itself is assumed —
// TODO confirm.
unsafe impl<T: Sync + Borrow<StanLibrary>> Sync for Rng<T> {}
unsafe impl<T: Send + Borrow<StanLibrary>> Send for Rng<T> {}
impl<T: Borrow<StanLibrary>> Drop for Rng<T> {
    /// Release the native RNG through the library that created it.
    fn drop(&mut self) {
        let library: &StanLibrary = self.lib.borrow();
        let raw = self.rng.as_ptr();
        // SAFETY: `raw` came from `bs_rng_construct` on this same library and is
        // destroyed exactly once, here.
        unsafe { library.lib.bs_rng_destruct(raw) };
    }
}
impl<T: Borrow<StanLibrary>> Rng<T> {
    /// Create a generator from the library held by `lib`, seeded with `seed`.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::ConstructFailed`] with the library's message
    /// if `bs_rng_construct` returns a null pointer.
    pub fn new(lib: T, seed: u32) -> Result<Self> {
        let mut err = ErrorMsg::new(lib.borrow());
        let bridge = &lib.borrow().lib;
        // SAFETY: `err.as_ptr()` is a valid error out-pointer for this call.
        let raw = unsafe { bridge.bs_rng_construct(seed as c_uint, err.as_ptr()) };
        match NonNull::new(raw) {
            Some(rng) => {
                // `err` borrows `lib`; end that borrow before moving `lib`.
                drop(err);
                Ok(Self { rng, lib })
            }
            None => Err(BridgeStanError::ConstructFailed(err.message())),
        }
    }
}
/// Owner of an error string that a BridgeStan call may write through an
/// out-pointer; the string is released with `bs_free_error_msg` on drop.
struct ErrorMsg<'lib> {
// Null until the library stores an error message here.
msg: *mut c_char,
// Library that allocated `msg` and must free it.
lib: &'lib StanLibrary,
}
impl Drop for ErrorMsg<'_> {
    /// Free the stored message, if the library ever wrote one.
    fn drop(&mut self) {
        if self.msg.is_null() {
            return;
        }
        // SAFETY: a non-null `msg` was written by `lib` through `as_ptr` and is
        // freed exactly once, here.
        unsafe { self.lib.lib.bs_free_error_msg(self.msg) };
    }
}
impl<'lib> ErrorMsg<'lib> {
fn new(lib: &'lib StanLibrary) -> Self {
Self {
msg: std::ptr::null_mut(),
lib,
}
}
fn as_ptr(&mut self) -> *mut *mut c_char {
&mut self.msg
}
fn message(&self) -> String {
NonNull::new(self.msg)
.map(|msg| {
unsafe { CStr::from_ptr(msg.as_ptr()) }
.to_string_lossy()
.to_string()
})
.expect("Stan returned an error but no error message")
}
}
impl<T: Borrow<StanLibrary>> Model<T> {
    /// FFI function table of the library this model was created from.
    fn ffi_lib(&self) -> &ffi::BridgeStan {
        &self.lib.borrow().lib
    }

    /// Construct a model from the library held by `lib`.
    ///
    /// `data` is passed to `bs_model_construct` as a C string, or as a null
    /// pointer when it is `None`. NOTE(review): presumably this is JSON data or
    /// a path to a JSON file — confirm against the BridgeStan C API. `seed` is
    /// forwarded unchanged.
    ///
    /// # Errors
    ///
    /// - [`BridgeStanError::ConstructFailed`] with the library's message if
    ///   `bs_model_construct` returns a null model.
    /// - [`BridgeStanError::StanThreads`] carrying [`Model::info`] if that
    ///   string does not contain `STAN_THREADS=true`. The freshly constructed
    ///   model is destroyed before returning.
    pub fn new<D: AsRef<CStr>>(lib: T, data: Option<D>, seed: u32) -> Result<Self> {
        let mut err = ErrorMsg::new(lib.borrow());
        let data_ptr = data
            .as_ref()
            .map(|data| data.as_ref().as_ptr())
            .unwrap_or(null());
        // SAFETY: `data_ptr` is null or points into `data`, which is kept alive
        // until after this call; `err.as_ptr()` is a valid error out-pointer.
        let model = unsafe {
            lib.borrow()
                .lib
                .bs_model_construct(data_ptr, seed, err.as_ptr())
        };
        // `data_ptr` is a raw pointer the borrow checker does not track, so
        // `data` is released explicitly only once the FFI call has returned.
        drop(data);
        if let Some(model) = NonNull::new(model) {
            // `err` borrows `lib`; end that borrow before `lib` moves into `Self`.
            drop(err);
            let model = Self { model, lib };
            let info = model.info();
            // `Model` is `Sync`, so the library must support concurrent use.
            if !info.to_string_lossy().contains("STAN_THREADS=true") {
                Err(BridgeStanError::StanThreads(
                    info.to_string_lossy().into_owned(),
                ))
            } else {
                Ok(model)
            }
        } else {
            Err(BridgeStanError::ConstructFailed(err.message()))
        }
    }

    /// Borrow the library this model was created from.
    pub fn ref_library(&self) -> &StanLibrary {
        self.lib.borrow()
    }

    /// Create a new [`Rng`] from this model's library, seeded with `seed`.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::ConstructFailed`] if the library cannot create
    /// the generator.
    pub fn new_rng(&self, seed: u32) -> Result<Rng<&StanLibrary>> {
        Rng::new(self.ref_library(), seed)
    }

    /// The model's name, as reported by `bs_name`.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::InvalidString`] if the name is not valid UTF-8.
    pub fn name(&self) -> Result<&str> {
        // SAFETY: `self.model` is a live model. The returned string is assumed to
        // be NUL-terminated and to live as long as the model — TODO confirm.
        let cstr = unsafe { CStr::from_ptr(self.ffi_lib().bs_name(self.model.as_ptr())) };
        Ok(cstr.to_str()?)
    }

    /// Model information string, as reported by `bs_model_info`.
    ///
    /// [`Model::new`] checks this string for `STAN_THREADS=true`.
    pub fn info(&self) -> &CStr {
        // SAFETY: `self.model` is a live model. The returned string is assumed to
        // be NUL-terminated and to live as long as the model — TODO confirm.
        unsafe { CStr::from_ptr(self.ffi_lib().bs_model_info(self.model.as_ptr())) }
    }

    /// Names of the constrained parameters as one string from `bs_param_names`.
    ///
    /// `include_tp` and `include_gq` select whether transformed parameters and
    /// generated quantities are included. NOTE(review): the names are
    /// presumably comma-separated — confirm against the BridgeStan C API.
    ///
    /// # Panics
    ///
    /// Panics if the library returns a string that is not valid UTF-8.
    pub fn param_names(&self, include_tp: bool, include_gq: bool) -> &str {
        // SAFETY: `self.model` is a live model. The returned string is assumed to
        // be NUL-terminated and to live as long as the model — TODO confirm.
        let cstr = unsafe {
            CStr::from_ptr(self.ffi_lib().bs_param_names(
                self.model.as_ptr(),
                include_tp,
                include_gq,
            ))
        };
        cstr.to_str()
            .expect("Stan model has invalid parameter names")
    }

    /// Names of the unconstrained parameters as one string from
    /// `bs_param_unc_names`.
    ///
    /// This only reads from the model, so it takes `&self` like
    /// [`Model::param_names`]. Callers holding `&mut Model` can still call it.
    ///
    /// # Panics
    ///
    /// Panics if the library returns a string that is not valid UTF-8.
    pub fn param_unc_names(&self) -> &str {
        // SAFETY: `self.model` is a live model. The returned string is assumed to
        // be NUL-terminated and to live as long as the model — TODO confirm.
        let cstr =
            unsafe { CStr::from_ptr(self.ffi_lib().bs_param_unc_names(self.model.as_ptr())) };
        cstr.to_str()
            .expect("Stan model has invalid parameter names")
    }

    /// Number of constrained parameters, with the same `include_tp` and
    /// `include_gq` selection as [`Model::param_names`].
    ///
    /// # Panics
    ///
    /// Panics if the library reports a count that does not fit in `usize`
    /// (e.g. a negative number).
    pub fn param_num(&self, include_tp: bool, include_gq: bool) -> usize {
        // SAFETY: `self.model` is a live model.
        unsafe {
            self.ffi_lib()
                .bs_param_num(self.model.as_ptr(), include_tp, include_gq)
        }
        .try_into()
        .expect("Stan returned an invalid number of parameters")
    }

    /// Number of unconstrained parameters, i.e. the length every `theta_unc`
    /// argument must have.
    ///
    /// # Panics
    ///
    /// Panics if the library reports a count that does not fit in `usize`.
    pub fn param_unc_num(&self) -> usize {
        // SAFETY: `self.model` is a live model.
        unsafe { self.ffi_lib().bs_param_unc_num(self.model.as_ptr()) }
            .try_into()
            .expect("Stan returned an invalid number of parameters")
    }

    /// Evaluate the log density at the unconstrained point `theta_unc`.
    ///
    /// `propto` and `jacobian` are forwarded to `bs_log_density`.
    /// NOTE(review): presumably `propto` drops constant terms and `jacobian`
    /// adds the change-of-variables adjustment — confirm.
    ///
    /// # Panics
    ///
    /// Panics if `theta_unc.len() != self.param_unc_num()`.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::EvaluationFailed`] with the library's message
    /// if the call reports failure (non-zero return code).
    pub fn log_density(&self, theta_unc: &[f64], propto: bool, jacobian: bool) -> Result<f64> {
        let n = self.param_unc_num();
        assert_eq!(
            theta_unc.len(),
            n,
            "Argument 'theta_unc' must be the same size as the number of parameters!"
        );
        let mut val = 0.0;
        let mut err = ErrorMsg::new(self.lib.borrow());
        // SAFETY: `theta_unc` holds `n` values as checked above; `val` and `err`
        // are valid out-pointers for the duration of the call.
        let rc = unsafe {
            self.ffi_lib().bs_log_density(
                self.model.as_ptr(),
                propto,
                jacobian,
                theta_unc.as_ptr(),
                &mut val,
                err.as_ptr(),
            )
        };
        if rc == 0 {
            Ok(val)
        } else {
            Err(BridgeStanError::EvaluationFailed(err.message()))
        }
    }

    /// Like [`Model::log_density`], but also writes the gradient into `grad`.
    ///
    /// # Panics
    ///
    /// Panics if `theta_unc` or `grad` does not have `self.param_unc_num()`
    /// elements.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::EvaluationFailed`] with the library's message
    /// if the call reports failure (non-zero return code).
    pub fn log_density_gradient(
        &self,
        theta_unc: &[f64],
        propto: bool,
        jacobian: bool,
        grad: &mut [f64],
    ) -> Result<f64> {
        let n = self.param_unc_num();
        assert_eq!(
            theta_unc.len(),
            n,
            "Argument 'theta_unc' must be the same size as the number of parameters!"
        );
        assert_eq!(
            grad.len(),
            n,
            "Argument 'grad' must be the same size as the number of parameters!"
        );
        let mut val = 0.0;
        let mut err = ErrorMsg::new(self.lib.borrow());
        // SAFETY: `theta_unc` and `grad` hold `n` values as checked above; `val`
        // and `err` are valid out-pointers for the duration of the call.
        let rc = unsafe {
            self.ffi_lib().bs_log_density_gradient(
                self.model.as_ptr(),
                propto,
                jacobian,
                theta_unc.as_ptr(),
                &mut val,
                grad.as_mut_ptr(),
                err.as_ptr(),
            )
        };
        if rc == 0 {
            Ok(val)
        } else {
            Err(BridgeStanError::EvaluationFailed(err.message()))
        }
    }

    /// Like [`Model::log_density_gradient`], but also writes the Hessian
    /// (`n * n` values, where `n = self.param_unc_num()`) into `hessian`.
    ///
    /// # Panics
    ///
    /// Panics if `theta_unc` or `grad` does not have `n` elements, if `n * n`
    /// overflows `usize`, or if `hessian` does not have `n * n` elements.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::EvaluationFailed`] with the library's message
    /// if the call reports failure (non-zero return code).
    pub fn log_density_hessian(
        &self,
        theta_unc: &[f64],
        propto: bool,
        jacobian: bool,
        grad: &mut [f64],
        hessian: &mut [f64],
    ) -> Result<f64> {
        let n = self.param_unc_num();
        assert_eq!(
            theta_unc.len(),
            n,
            "Argument 'theta_unc' must be the same size as the number of parameters!"
        );
        assert_eq!(
            grad.len(),
            n,
            "Argument 'grad' must be the same size as the number of parameters!"
        );
        assert_eq!(
            hessian.len(),
            n.checked_mul(n).expect("Overflow for size of hessian"),
            "Argument 'hessian' must be the same size as the number of parameters squared!"
        );
        let mut val = 0.0;
        let mut err = ErrorMsg::new(self.lib.borrow());
        // SAFETY: all slices have the lengths checked above; `val` and `err` are
        // valid out-pointers for the duration of the call.
        let rc = unsafe {
            self.ffi_lib().bs_log_density_hessian(
                self.model.as_ptr(),
                propto,
                jacobian,
                theta_unc.as_ptr(),
                &mut val,
                grad.as_mut_ptr(),
                hessian.as_mut_ptr(),
                err.as_ptr(),
            )
        };
        if rc == 0 {
            Ok(val)
        } else {
            Err(BridgeStanError::EvaluationFailed(err.message()))
        }
    }

    /// Evaluate the log density at `theta_unc` and write the product of its
    /// Hessian with `v` into `hvp`.
    ///
    /// # Panics
    ///
    /// Panics if `theta_unc`, `v` or `hvp` does not have
    /// `self.param_unc_num()` elements.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::EvaluationFailed`] with the library's message
    /// if the call reports failure (non-zero return code).
    pub fn log_density_hessian_vector_product(
        &self,
        theta_unc: &[f64],
        v: &[f64],
        propto: bool,
        jacobian: bool,
        hvp: &mut [f64],
    ) -> Result<f64> {
        let n = self.param_unc_num();
        assert_eq!(
            theta_unc.len(),
            n,
            "Argument 'theta_unc' must be the same size as the number of parameters!"
        );
        assert_eq!(
            v.len(),
            n,
            "Argument 'v' must be the same size as the number of parameters!"
        );
        assert_eq!(
            hvp.len(),
            n,
            "Argument 'hvp' must be the same size as the number of parameters!"
        );
        let mut val = 0.0;
        let mut err = ErrorMsg::new(self.lib.borrow());
        // SAFETY: all slices hold `n` values as checked above; `val` and `err`
        // are valid out-pointers for the duration of the call.
        let rc = unsafe {
            self.ffi_lib().bs_log_density_hessian_vector_product(
                self.model.as_ptr(),
                propto,
                jacobian,
                theta_unc.as_ptr(),
                v.as_ptr(),
                &mut val,
                hvp.as_mut_ptr(),
                err.as_ptr(),
            )
        };
        if rc == 0 {
            Ok(val)
        } else {
            Err(BridgeStanError::EvaluationFailed(err.message()))
        }
    }

    /// Map the unconstrained point `theta_unc` to constrained values in `out`.
    ///
    /// `out` must have `self.param_num(include_tp, include_gq)` elements. When
    /// `include_gq` is true, an `rng` created by the same library must be given.
    /// Otherwise a null RNG pointer is passed if `rng` is `None`.
    ///
    /// # Panics
    ///
    /// Panics if a slice has the wrong length, if `include_gq` is true but `rng`
    /// is `None`, or if `rng` comes from a different [`StanLibrary`].
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::EvaluationFailed`] with the library's message
    /// if the call reports failure (non-zero return code).
    pub fn param_constrain<R: Borrow<StanLibrary>>(
        &self,
        theta_unc: &[f64],
        include_tp: bool,
        include_gq: bool,
        out: &mut [f64],
        rng: Option<&mut Rng<R>>,
    ) -> Result<()> {
        let n = self.param_unc_num();
        assert_eq!(
            theta_unc.len(),
            n,
            "Argument 'theta_unc' must be the same size as the number of parameters!"
        );
        let out_n = self.param_num(include_tp, include_gq);
        assert_eq!(
            out.len(),
            out_n,
            "Argument 'out' must be the same size as the number of parameters!"
        );
        if include_gq {
            assert!(
                rng.is_some(),
                "Rng was not provided even though generated quantities are requested."
            );
        }
        if let Some(rng) = &rng {
            // The native RNG may only be used with the library that created it.
            assert!(
                rng.lib.borrow().id == self.lib.borrow().id,
                "Rng and model must come from the same Stan library"
            );
        }
        let mut err = ErrorMsg::new(self.lib.borrow());
        // SAFETY: slice lengths are checked above; the RNG (if any) belongs to
        // this library and is borrowed mutably for the call; `err` is a valid
        // error out-pointer.
        let rc = unsafe {
            self.ffi_lib().bs_param_constrain(
                self.model.as_ptr(),
                include_tp,
                include_gq,
                theta_unc.as_ptr(),
                out.as_mut_ptr(),
                rng.map(|rng| rng.rng.as_ptr()).unwrap_or(null_mut()),
                err.as_ptr(),
            )
        };
        if rc == 0 {
            Ok(())
        } else {
            Err(BridgeStanError::EvaluationFailed(err.message()))
        }
    }

    /// Map constrained parameter values `theta` to the unconstrained point
    /// `theta_unc`.
    ///
    /// # Panics
    ///
    /// Panics if `theta_unc` does not have `self.param_unc_num()` elements or
    /// `theta` does not have `self.param_num(false, false)` elements.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::EvaluationFailed`] with the library's message
    /// if the call reports failure (non-zero return code).
    pub fn param_unconstrain(&self, theta: &[f64], theta_unc: &mut [f64]) -> Result<()> {
        assert_eq!(
            theta_unc.len(),
            self.param_unc_num(),
            "Argument 'theta_unc' must be the same size as the number of parameters!"
        );
        assert_eq!(
            theta.len(),
            self.param_num(false, false),
            "Argument 'theta' must be the same size as the number of parameters!"
        );
        let mut err = ErrorMsg::new(self.lib.borrow());
        // SAFETY: both slices have the lengths checked above; `err` is a valid
        // error out-pointer.
        let rc = unsafe {
            self.ffi_lib().bs_param_unconstrain(
                self.model.as_ptr(),
                theta.as_ptr(),
                theta_unc.as_mut_ptr(),
                err.as_ptr(),
            )
        };
        if rc == 0 {
            Ok(())
        } else {
            Err(BridgeStanError::EvaluationFailed(err.message()))
        }
    }

    /// Like [`Model::param_unconstrain`], but reads the constrained values from
    /// the C string `json` via `bs_param_unconstrain_json`.
    ///
    /// # Panics
    ///
    /// Panics if `theta_unc` does not have `self.param_unc_num()` elements.
    ///
    /// # Errors
    ///
    /// Returns [`BridgeStanError::EvaluationFailed`] with the library's message
    /// if the call reports failure (non-zero return code).
    pub fn param_unconstrain_json<S: AsRef<CStr>>(
        &self,
        json: S,
        theta_unc: &mut [f64],
    ) -> Result<()> {
        assert_eq!(
            theta_unc.len(),
            self.param_unc_num(),
            "Argument 'theta_unc' must be the same size as the number of parameters!"
        );
        let mut err = ErrorMsg::new(self.lib.borrow());
        // SAFETY: `json` is a NUL-terminated string alive for the whole call;
        // `theta_unc` has the checked length; `err` is a valid out-pointer.
        let rc = unsafe {
            self.ffi_lib().bs_param_unconstrain_json(
                self.model.as_ptr(),
                json.as_ref().as_ptr(),
                theta_unc.as_mut_ptr(),
                err.as_ptr(),
            )
        };
        if rc == 0 {
            Ok(())
        } else {
            Err(BridgeStanError::EvaluationFailed(err.message()))
        }
    }
}
impl<T> Model<T>
where
    T: Borrow<StanLibrary> + Clone,
{
    /// Return a clone of the library handle this model holds (for example, a
    /// fresh `Arc` pointing at the same library).
    pub fn clone_library_ref(&self) -> T {
        T::clone(&self.lib)
    }
}
impl<T: Borrow<StanLibrary>> Drop for Model<T> {
    /// Release the native model through the library that created it.
    fn drop(&mut self) {
        let raw = self.model.as_ptr();
        // SAFETY: `raw` came from `bs_model_construct` on this same library and
        // is destroyed exactly once, here.
        unsafe { self.ffi_lib().bs_model_destruct(raw) }
    }
}