use crate::{
multivariate::MultivariateDensity,
pytypes::{Float, PyUnivariate},
};
use nalgebra::Dyn;
use pyo3::{Bound, PyResult, exceptions::PyTypeError, prelude::*, types::PyList};
/// Python-facing multivariate distribution, exposed to Python as `Multivariate`.
///
/// Holds a list of [`PyUnivariate`] distributions. When built from a
/// [`MultivariateDensity`], these are clones of that density's marginals.
#[derive(Clone)]
#[pyclass(from_py_object, name = "Multivariate")]
pub struct PyMultivariate {
    /// The univariate distributions, readable from Python as the `uvpdfs` attribute.
    #[pyo3(get)]
    uvpdfs: Vec<PyUnivariate>,
}
impl PyMultivariate {
    /// Borrows the univariate distributions held by this multivariate, for use
    /// from Rust code (Python reads the same data via the `uvpdfs` attribute).
    pub fn uvpdfs(&self) -> &Vec<PyUnivariate> {
        let Self { uvpdfs } = self;
        uvpdfs
    }
}
impl From<MultivariateDensity<Float, Dyn>> for PyMultivariate {
    /// Wraps a native multivariate density for Python by cloning each of its
    /// marginals into a [`PyUnivariate`], preserving their order.
    fn from(dist: MultivariateDensity<Float, Dyn>) -> Self {
        let mut uvpdfs = Vec::new();
        for marginal in dist.marginals().iter() {
            uvpdfs.push(PyUnivariate::from(marginal.clone()));
        }
        Self { uvpdfs }
    }
}
#[pymethods]
impl PyMultivariate {
    /// Creates a `Multivariate` from a Python list of `Univariate` objects.
    ///
    /// Each list element is extracted into a [`PyUnivariate`]; the stored order
    /// matches the order of the list.
    ///
    /// # Errors
    ///
    /// Raises `TypeError` for the first element that cannot be extracted as a
    /// [`PyUnivariate`]; the message includes that element's index.
    #[new]
    pub fn new(priors: Bound<'_, PyList>) -> PyResult<Self> {
        let uvpdfs = priors
            .iter()
            .enumerate()
            .map(|(index, item)| {
                // Extraction already yields an owned value, so no extra clone is needed.
                item.extract::<PyUnivariate>().map_err(|_| {
                    PyTypeError::new_err(format!(
                        "failed to convert one of the members of the list argument to a PyUnivariate type (element at index {index})"
                    ))
                })
            })
            // Collecting into a Result stops at the first failing element.
            .collect::<PyResult<Vec<_>>>()?;
        Ok(Self { uvpdfs })
    }

    /// Returns the `name()` of each contained univariate, in storage order.
    pub fn names(&self) -> Vec<String> {
        self.uvpdfs.iter().map(|u| u.name().to_string()).collect()
    }

    /// Returns the `typename()` of each contained univariate, in storage order.
    pub fn typenames(&self) -> Vec<String> {
        self.uvpdfs
            .iter()
            .map(|u| u.typename().to_string())
            .collect()
    }
}