use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
#[cfg(feature = "radar_examples")]
use pyo3::types::PyComplex;
use pyo3::types::{PyDict, PyList};
use std::collections::HashMap;
use std::sync::Arc;
use crate::builder::Graph;
use crate::dag::{Dag, PredictTarget};
use crate::distribution::{DistContext, Distribution};
use crate::graph_data::GraphData;
use crate::stat_result::StatResult;
/// Probability distribution exposed to Python as `dagex.Distribution`.
///
/// Thin wrapper around [`Distribution`]. Instances are produced by the
/// module-level constructors (`normal`, `uniform`, ...) and by `StatResult`
/// lookups.
#[pyclass(name = "Distribution")]
#[derive(Clone)]
struct PyDistribution {
    inner: Distribution,
}

#[pymethods]
impl PyDistribution {
    /// Mean of the distribution.
    #[getter]
    fn mean(&self) -> f64 {
        self.inner.mean()
    }

    /// Standard deviation of the distribution.
    #[getter]
    fn std(&self) -> f64 {
        self.inner.std()
    }

    /// Variance of the distribution.
    #[getter]
    fn variance(&self) -> f64 {
        self.inner.variance()
    }

    /// 5th percentile (shorthand for `percentile(0.05)`).
    #[getter]
    fn p5(&self) -> f64 {
        self.percentile(0.05)
    }

    /// Median (shorthand for `percentile(0.50)`).
    #[getter]
    fn p50(&self) -> f64 {
        self.percentile(0.50)
    }

    /// 95th percentile (shorthand for `percentile(0.95)`).
    #[getter]
    fn p95(&self) -> f64 {
        self.percentile(0.95)
    }

    /// Percentile at `p`, on the same fractional scale the `p5`/`p50`/`p95`
    /// getters use (e.g. `0.05` for the 5th percentile).
    fn percentile(&self, p: f64) -> f64 {
        self.inner.percentile(p)
    }

    /// Raw samples as a Python `list` for empirical distributions, `None` otherwise.
    #[getter]
    fn samples(&self, py: Python) -> PyObject {
        if let Distribution::Empirical { samples } = &self.inner {
            let owned: Vec<f64> = samples.as_ref().clone();
            owned.to_object(py)
        } else {
            py.None()
        }
    }

    /// Draw `n` samples from the distribution.
    fn sample_n(&self, n: usize) -> Vec<f64> {
        self.inner.sample_n(n)
    }

    /// Human-readable summary, rendered through its `Display` impl.
    fn summary(&self) -> String {
        self.inner.summary().to_string()
    }

    /// `repr()` is the `Display` rendering of the wrapped distribution.
    fn __repr__(&self) -> String {
        self.inner.to_string()
    }
}
#[pyclass(name = "StatResult")]
struct PyStatResult {
inner: StatResult,
}
#[pymethods]
impl PyStatResult {
fn __getitem__(&self, py: Python, key: &str) -> PyResult<PyObject> {
match self.inner.get(key) {
Some(dist) => Ok(PyDistribution { inner: dist.clone() }.into_py(py)),
None => Err(PyValueError::new_err(format!(
"Variable '{}' not found in StatResult",
key
))),
}
}
fn get(&self, py: Python, key: &str) -> PyObject {
match self.inner.get(key) {
Some(dist) => PyDistribution { inner: dist.clone() }.into_py(py),
None => py.None(),
}
}
fn for_branch(&self, py: Python, branch_id: usize) -> PyObject {
dist_context_to_py_dict(py, self.inner.for_branch(branch_id))
}
fn for_variant(&self, py: Python, variant_idx: usize) -> PyObject {
dist_context_to_py_dict(py, self.inner.for_variant(variant_idx))
}
fn keys(&self, py: Python) -> PyObject {
let mut ks: Vec<&str> = self
.inner
.dist_context
.keys()
.filter(|k| !k.starts_with("__branch_"))
.map(|k| k.as_str())
.collect();
ks.sort();
ks.to_object(py)
}
fn print_summary(&self) {
self.inner.print_summary();
}
fn __repr__(&self) -> String {
let keys: Vec<&str> = self
.inner
.dist_context
.keys()
.filter(|k| !k.starts_with("__branch_"))
.map(|k| k.as_str())
.collect();
format!("StatResult(vars={:?})", keys)
}
#[getter]
fn particles(&self, py: Python) -> PyObject {
match &self.inner.particles {
None => py.None(),
Some(parts) => {
let py_list = pyo3::types::PyList::empty(py);
for particle in parts {
let d = PyDict::new(py);
for (k, v) in particle {
let _ = d.set_item(k, v);
}
let _ = py_list.append(d);
}
py_list.to_object(py)
}
}
}
}
/// Convert an optional [`DistContext`] into a Python `dict` of `Distribution`s.
///
/// A missing context yields an empty `dict`. Individual insertion failures are
/// ignored (best effort).
fn dist_context_to_py_dict(py: Python, ctx: Option<&DistContext>) -> PyObject {
    let out = PyDict::new(py);
    for (name, dist) in ctx.into_iter().flatten() {
        let wrapped = PyDistribution { inner: dist.clone() }.into_py(py);
        let _ = out.set_item(name, wrapped);
    }
    out.to_object(py)
}
#[pyclass(name = "Graph")]
struct PyGraph {
graph: Option<Graph>,
}
#[pymethods]
impl PyGraph {
#[new]
fn new() -> Self {
PyGraph {
graph: Some(Graph::new()),
}
}
#[pyo3(signature = (function=None, label=None, inputs=None, outputs=None))]
fn add(
&mut self,
function: Option<PyObject>,
label: Option<String>,
inputs: Option<&PyAny>,
outputs: Option<&PyAny>,
) -> PyResult<()> {
let graph = self
.graph
.as_mut()
.ok_or_else(|| PyValueError::new_err("Graph has already been built or consumed"))?;
let input_vec = if let Some(inp) = inputs {
parse_mapping(inp)?
} else {
Vec::new()
};
let output_vec = if let Some(out) = outputs {
parse_mapping(out)?
} else {
Vec::new()
};
let input_refs: Vec<(&str, &str)> = input_vec
.iter()
.map(|(a, b)| (a.as_str(), b.as_str()))
.collect();
let output_refs: Vec<(&str, &str)> = output_vec
.iter()
.map(|(a, b)| (a.as_str(), b.as_str()))
.collect();
if let Some(py_func) = function {
let rust_function = create_python_node_function(py_func);
graph.add(
rust_function,
label.as_deref(),
if input_refs.is_empty() {
None
} else {
Some(input_refs)
},
if output_refs.is_empty() {
None
} else {
Some(output_refs)
},
);
} else {
let noop = |_: &HashMap<String, GraphData>| HashMap::new();
graph.add(
noop,
label.as_deref(),
if input_refs.is_empty() {
None
} else {
Some(input_refs)
},
if output_refs.is_empty() {
None
} else {
Some(output_refs)
},
);
}
Ok(())
}
fn branch(&mut self, mut subgraph: PyRefMut<PyGraph>) -> PyResult<usize> {
let graph = self
.graph
.as_mut()
.ok_or_else(|| PyValueError::new_err("Graph has already been built or consumed"))?;
let subgraph_inner = subgraph
.graph
.take()
.ok_or_else(|| PyValueError::new_err("Subgraph has already been built or consumed"))?;
Ok(graph.branch(subgraph_inner))
}
#[pyo3(signature = (functions, label=None, inputs=None, outputs=None))]
fn variants(
&mut self,
functions: Vec<PyObject>,
label: Option<String>,
inputs: Option<&PyAny>,
outputs: Option<&PyAny>,
) -> PyResult<()> {
let graph = self
.graph
.as_mut()
.ok_or_else(|| PyValueError::new_err("Graph has already been built or consumed"))?;
let input_vec = if let Some(inp) = inputs {
parse_mapping(inp)?
} else {
Vec::new()
};
let output_vec = if let Some(out) = outputs {
parse_mapping(out)?
} else {
Vec::new()
};
let input_refs: Vec<(&str, &str)> = input_vec
.iter()
.map(|(a, b)| (a.as_str(), b.as_str()))
.collect();
let output_refs: Vec<(&str, &str)> = output_vec
.iter()
.map(|(a, b)| (a.as_str(), b.as_str()))
.collect();
let rust_functions: Vec<_> = functions
.iter()
.map(|func| create_python_node_function(func.clone()))
.collect();
graph.variants(
rust_functions,
label.as_deref(),
if input_refs.is_empty() {
None
} else {
Some(input_refs)
},
if output_refs.is_empty() {
None
} else {
Some(output_refs)
},
);
Ok(())
}
fn build(&mut self) -> PyResult<PyDag> {
let graph = self
.graph
.take()
.ok_or_else(|| PyValueError::new_err("Graph has already been built"))?;
Ok(PyDag { dag: graph.build() })
}
fn set_dist_transfer(&mut self, label: String, transfer_fn: PyObject) -> PyResult<()> {
let graph = self
.graph
.as_mut()
.ok_or_else(|| PyValueError::new_err("Graph has already been built or consumed"))?;
let rust_transfer = create_python_dist_transfer(transfer_fn);
graph.set_dist_transfer_for(&label, Arc::new(rust_transfer));
Ok(())
}
}
/// Executable DAG exposed to Python as `dagex.Dag` (produced by `Graph.build()`).
#[pyclass(name = "Dag")]
struct PyDag {
    dag: Dag,
}

#[pymethods]
impl PyDag {
    /// Run the DAG and return the resulting context as a Python `dict`.
    ///
    /// The GIL is released while the Rust executor runs; Python node functions
    /// re-acquire it individually.
    #[pyo3(signature = (parallel=false, max_threads=None))]
    fn execute(
        &self,
        py: Python,
        parallel: bool,
        max_threads: Option<usize>,
    ) -> PyResult<PyObject> {
        let context = py.allow_threads(|| self.dag.execute(parallel, max_threads));
        let py_dict = PyDict::new(py);
        for (key, value) in context.iter() {
            py_dict.set_item(key, graph_data_to_python(py, value))?;
        }
        Ok(py_dict.to_object(py))
    }

    /// Mermaid diagram source for the DAG.
    fn to_mermaid(&self) -> String {
        self.dag.to_mermaid()
    }

    /// Number of nodes in the DAG.
    fn node_count(&self) -> usize {
        self.dag.nodes().len()
    }

    /// Sorted, de-duplicated labels of all labelled nodes.
    fn node_labels(&self, py: Python) -> PyObject {
        let mut labels: Vec<&str> = self
            .dag
            .nodes()
            .iter()
            .filter_map(|n| n.label.as_deref())
            .collect();
        labels.sort_unstable();
        labels.dedup();
        labels.to_object(py)
    }

    /// Sorted, de-duplicated branch ids present on nodes.
    fn branch_ids(&self, py: Python) -> PyObject {
        let mut ids: Vec<usize> = self
            .dag
            .nodes()
            .iter()
            .filter_map(|n| n.branch_id)
            .collect();
        ids.sort_unstable();
        ids.dedup();
        ids.to_object(py)
    }

    /// Sorted, de-duplicated variant indices present on nodes.
    fn variant_indices(&self, py: Python) -> PyObject {
        let mut idxs: Vec<usize> = self
            .dag
            .nodes()
            .iter()
            .filter_map(|n| n.variant_index)
            .collect();
        idxs.sort_unstable();
        idxs.dedup();
        idxs.to_object(py)
    }

    /// Propagate the input distributions, optionally stopping at a target.
    ///
    /// When several targets are given, precedence is `at_node`, then
    /// `at_branch`, then `at_variant`.
    ///
    /// Raises `ValueError` if any value in `inputs` is not a `Distribution`.
    #[pyo3(signature = (inputs, n_samples=None, at_node=None, at_branch=None, at_variant=None))]
    fn predict_at(
        &self,
        py: Python,
        inputs: &PyDict,
        n_samples: Option<usize>,
        at_node: Option<String>,
        at_branch: Option<usize>,
        at_variant: Option<usize>,
    ) -> PyResult<PyStatResult> {
        let dist_ctx = dist_context_from_py(inputs)?;
        let target: Option<PredictTarget> = at_node
            .map(PredictTarget::NodeLabel)
            .or_else(|| at_branch.map(PredictTarget::BranchId))
            .or_else(|| at_variant.map(PredictTarget::VariantIndex));
        let stat = py.allow_threads(|| {
            self.dag.predict_at(dist_ctx, n_samples, target.as_ref())
        });
        Ok(PyStatResult { inner: stat })
    }

    /// Propagate the input distributions through the whole DAG.
    ///
    /// Raises `ValueError` if any value in `inputs` is not a `Distribution`.
    #[pyo3(signature = (inputs, n_samples=1000))]
    fn predict(
        &self,
        py: Python,
        inputs: &PyDict,
        n_samples: usize,
    ) -> PyResult<PyStatResult> {
        let dist_ctx = dist_context_from_py(inputs)?;
        let stat = py.allow_threads(|| self.dag.predict(dist_ctx, n_samples));
        Ok(PyStatResult { inner: stat })
    }
}

/// Convert a Python `dict[str, Distribution]` into a [`DistContext`].
///
/// Shared by `predict` and `predict_at` so both report the same helpful error
/// when a value is not a `Distribution`.
fn dist_context_from_py(inputs: &PyDict) -> PyResult<DistContext> {
    let mut dist_ctx: DistContext = HashMap::with_capacity(inputs.len());
    for (key, val) in inputs.iter() {
        let k: String = key.extract()?;
        let cell = val.downcast::<PyCell<PyDistribution>>().map_err(|_| {
            PyValueError::new_err(format!(
                "Value for key '{}' must be a Distribution (use dagex.normal(), dagex.gamma(), etc.)",
                k
            ))
        })?;
        dist_ctx.insert(k, cell.borrow().inner.clone());
    }
    Ok(dist_ctx)
}
/// Parse an `inputs`/`outputs` argument into owned string pairs.
///
/// Accepts either a `dict[str, str]` (pairs in dict iteration order) or a
/// `list` of `(str, str)` tuples. The first element that fails to convert
/// aborts parsing with its extraction error. Any other object type raises
/// `ValueError`.
fn parse_mapping(obj: &PyAny) -> PyResult<Vec<(String, String)>> {
    if let Ok(dict) = obj.downcast::<PyDict>() {
        return dict
            .iter()
            .map(|(key, value)| -> PyResult<(String, String)> {
                Ok((key.extract()?, value.extract()?))
            })
            .collect();
    }
    if let Ok(list) = obj.downcast::<PyList>() {
        return list
            .iter()
            .map(|item| item.extract::<(String, String)>())
            .collect();
    }
    Err(PyValueError::new_err(
        "inputs/outputs must be a dict or list of tuples",
    ))
}
fn create_python_node_function(
py_func: PyObject,
) -> impl Fn(&HashMap<String, GraphData>) -> HashMap<String, GraphData>
+ Send
+ Sync
+ 'static {
let py_func = Arc::new(py_func);
move |inputs: &HashMap<String, GraphData>| {
Python::with_gil(|py| {
let py_inputs = PyDict::new(py);
for (key, value) in inputs.iter() {
if let Err(e) = py_inputs.set_item(key, graph_data_to_python(py, value)) {
let _ = py
.import("sys")
.and_then(|sys| sys.getattr("stderr"))
.and_then(|stderr| {
stderr.call_method1(
"write",
(format!("Error setting input '{}': {}\n", key, e),),
)
});
return HashMap::new();
}
}
let result = py_func.call1(py, (py_inputs,));
match result {
Ok(py_result) => {
if let Ok(result_dict) = py_result.downcast::<PyDict>(py) {
let mut output = HashMap::new();
for (key, value) in result_dict.iter() {
if let Ok(k) = key.extract::<String>() {
output.insert(k, python_to_graph_data(value));
}
}
output
} else {
let _ = py
.import("sys")
.and_then(|sys| sys.getattr("stderr"))
.and_then(|stderr| {
stderr.call_method1(
"write",
("Error: Python function did not return a dict\n",),
)
});
HashMap::new()
}
}
Err(e) => {
e.print(py);
HashMap::new()
}
}
})
}
}
/// Convert a [`GraphData`] value into the corresponding Python object.
///
/// Scalars, strings and vectors map directly. A `Map` is converted as follows:
/// * if its keys are exactly `"0"`, `"1"`, ..., `"<len-1>"` and every value is a
///   map containing both `"re"` and `"im"`, it becomes a `list` of `(re, im)`
///   tuples (a component that is not a float becomes `0.0`);
/// * if its keys are exactly `"0"` ... `"<len-1>"` otherwise, it becomes a
///   `list` of the converted values in index order;
/// * anything else (including an empty map) becomes a `dict` with recursively
///   converted values.
fn graph_data_to_python(py: Python, data: &GraphData) -> PyObject {
    match data {
        GraphData::Int(v) => v.to_object(py),
        GraphData::Float(v) => v.to_object(py),
        GraphData::String(s) => s.to_object(py),
        GraphData::FloatVec(v) => v.to_object(py),
        GraphData::IntVec(v) => v.to_object(py),
        GraphData::Map(m) => {
            // A map is list-like only if every index in 0..len is present under
            // its canonical decimal key; since keys are unique this means the key
            // set is exactly {"0", ..., "len-1"}. The previous `parse::<usize>()`
            // check also accepted spellings like "01" or "+1", whose entries were
            // then silently dropped when building the list.
            let is_dense =
                !m.is_empty() && (0..m.len()).all(|i| m.get(&i.to_string()).is_some());
            if !is_dense {
                let dict = PyDict::new(py);
                for (k, v) in m.iter() {
                    let _ = dict.set_item(k, graph_data_to_python(py, v));
                }
                return dict.to_object(py);
            }

            let is_complex_array = m.iter().all(|(_, v)| {
                v.as_map().map_or(false, |inner| {
                    inner.contains_key("re") && inner.contains_key("im")
                })
            });

            let list = PyList::empty(py);
            for value in (0..m.len()).filter_map(|i| m.get(&i.to_string())) {
                let item = match value.as_map() {
                    Some(inner) if is_complex_array => {
                        let re = inner
                            .get("re")
                            .and_then(|d| d.as_float())
                            .unwrap_or(0.0);
                        let im = inner
                            .get("im")
                            .and_then(|d| d.as_float())
                            .unwrap_or(0.0);
                        (re, im).to_object(py)
                    }
                    _ => graph_data_to_python(py, value),
                };
                let _ = list.append(item);
            }
            list.to_object(py)
        }
        GraphData::None => py.None(),
        #[cfg(feature = "python")]
        GraphData::PyObject(obj) => {
            obj.clone_ref(py)
        }
        #[cfg(feature = "radar_examples")]
        GraphData::Complex(c) => {
            PyComplex::from_doubles(py, c.re, c.im).to_object(py)
        }
        #[cfg(feature = "radar_examples")]
        GraphData::FloatArray(a) => {
            a.to_vec().to_object(py)
        }
        #[cfg(feature = "radar_examples")]
        GraphData::ComplexArray(a) => {
            let list = PyList::empty(py);
            for c in a.iter() {
                let py_complex = PyComplex::from_doubles(py, c.re, c.im);
                let _ = list.append(py_complex);
            }
            list.to_object(py)
        }
    }
}
/// Convert a Python value returned by a node function into [`GraphData`].
///
/// Integer forms are tried before floating-point forms: `f64` extraction also
/// accepts Python `int`s, so checking it first made the `Int` / `IntVec`
/// branches unreachable and turned every returned `int` into a float.
/// Conversion order:
/// * `int` → `Int`;
/// * `float` → `Float`;
/// * `str` → `String`;
/// * sequence of ints → `IntVec`;
/// * sequence of numbers → `FloatVec`;
/// * anything else is kept as an opaque `PyObject`.
///
/// NOTE(review): the final `GraphData::PyObject` arm is not `cfg`-gated, unlike
/// the matching arm in `graph_data_to_python`; presumably this module is only
/// built with the `python` feature — confirm.
fn python_to_graph_data(obj: &PyAny) -> GraphData {
    if let Ok(i) = obj.extract::<i64>() {
        return GraphData::Int(i);
    }
    // Floats, and ints too large for i64.
    if let Ok(f) = obj.extract::<f64>() {
        return GraphData::Float(f);
    }
    if let Ok(s) = obj.extract::<String>() {
        return GraphData::String(s);
    }
    if let Ok(list) = obj.extract::<Vec<i64>>() {
        return GraphData::IntVec(std::sync::Arc::new(list));
    }
    // Mixed int/float sequences fall through to here and become floats.
    if let Ok(list) = obj.extract::<Vec<f64>>() {
        return GraphData::FloatVec(std::sync::Arc::new(list));
    }
    GraphData::PyObject(obj.to_object(obj.py()))
}
/// Wrap a Python callable as a distribution-transfer function.
///
/// The callable receives a `dict` mapping names to `Distribution` objects and
/// may return `None` or a `dict` of `Distribution`s. Values that are not
/// `Distribution`s are skipped.
///
/// The wrapper yields `None` (no transfer) if:
/// * the callable returns `None`;
/// * the callable raises;
/// * it returns something other than a `dict`;
/// * the result contains a non-string key;
/// * the result contains no `Distribution` values.
///
/// Python exceptions are printed to `sys.stderr` rather than silently
/// discarded, matching how node functions report them.
fn create_python_dist_transfer(
    py_func: PyObject,
) -> impl Fn(&DistContext) -> Option<DistContext> + Send + Sync + 'static {
    move |input_dists: &DistContext| -> Option<DistContext> {
        Python::with_gil(|py| {
            let py_dict = PyDict::new(py);
            for (key, dist) in input_dists {
                let d = PyDistribution { inner: dist.clone() };
                if let Err(e) = py_dict.set_item(key, d.into_py(py)) {
                    e.print(py);
                    return None;
                }
            }
            let result = match py_func.call1(py, (py_dict,)) {
                Ok(result) => result,
                Err(e) => {
                    // Surface the traceback so a buggy transfer function is visible.
                    e.print(py);
                    return None;
                }
            };
            if result.is_none(py) {
                return None;
            }
            let result_dict = result.downcast::<PyDict>(py).ok()?;
            let mut output: DistContext = HashMap::new();
            for (key, val) in result_dict.iter() {
                let k: String = key.extract().ok()?;
                if let Ok(cell) = val.downcast::<PyCell<PyDistribution>>() {
                    output.insert(k, cell.borrow().inner.clone());
                }
            }
            if output.is_empty() {
                None
            } else {
                Some(output)
            }
        })
    }
}
#[pyfunction]
#[pyo3(signature = (mean, std))]
fn normal(mean: f64, std: f64) -> PyDistribution {
PyDistribution {
inner: Distribution::normal(mean, std),
}
}
#[pyfunction]
#[pyo3(signature = (low, high))]
fn uniform(low: f64, high: f64) -> PyDistribution {
PyDistribution {
inner: Distribution::uniform(low, high),
}
}
#[pyfunction]
#[pyo3(signature = (alpha, beta))]
fn beta(alpha: f64, beta: f64) -> PyDistribution {
PyDistribution {
inner: Distribution::beta(alpha, beta),
}
}
#[pyfunction]
#[pyo3(signature = (shape, rate))]
fn gamma(shape: f64, rate: f64) -> PyDistribution {
PyDistribution {
inner: Distribution::gamma(shape, rate),
}
}
#[pyfunction]
#[pyo3(signature = (mu, sigma))]
fn lognormal(mu: f64, sigma: f64) -> PyDistribution {
PyDistribution {
inner: Distribution::lognormal(mu, sigma),
}
}
#[pyfunction]
#[pyo3(signature = (value))]
fn deterministic(value: f64) -> PyDistribution {
PyDistribution {
inner: Distribution::deterministic(value),
}
}
#[pyfunction]
#[pyo3(signature = (samples))]
fn empirical(samples: Vec<f64>) -> PyDistribution {
PyDistribution {
inner: Distribution::empirical(samples),
}
}
/// Python module `dagex`: registers the wrapper classes and the distribution
/// constructor functions.
#[pymodule]
fn dagex(_py: Python, m: &PyModule) -> PyResult<()> {
    // Classes.
    m.add_class::<PyGraph>()?;
    m.add_class::<PyDag>()?;
    m.add_class::<PyDistribution>()?;
    m.add_class::<PyStatResult>()?;

    // Distribution constructors, registered in this order.
    let constructors = [
        wrap_pyfunction!(normal, m)?,
        wrap_pyfunction!(uniform, m)?,
        wrap_pyfunction!(beta, m)?,
        wrap_pyfunction!(gamma, m)?,
        wrap_pyfunction!(lognormal, m)?,
        wrap_pyfunction!(deterministic, m)?,
        wrap_pyfunction!(empirical, m)?,
    ];
    for constructor in constructors {
        m.add_function(constructor)?;
    }
    Ok(())
}