optimal_transport_rs/
lib.rs1use numpy::{PyArray1, PyArray2};
2use numpy::convert::IntoPyArray;
3use pyo3::prelude::*;
4use pyo3::wrap_pyfunction;
5mod sinkhorn;
6mod lp_solver;
7
8use sinkhorn::sinkhorn as impl_sinkhorn;
9use lp_solver::calculate_1D_ot as impl_calculate_1D_ot;
10
11#[pyfunction]
12fn calculate_1D_ot(py: Python<'_>, a: &PyArray1<i32>, b: &PyArray1<i32>, cost: &PyArray2<i32>) -> PyResult<(i32, Py<PyArray2<u32>>)> {
13 let (cost, transport_plan) = impl_calculate_1D_ot(&a.to_owned_array(), &b.to_owned_array(), &cost.to_owned_array());
14 Ok((cost, transport_plan.into_pyarray(py).to_owned()))
15}
16
17#[pyfunction]
18fn sinkhorn(py: Python<'_>, a: &PyArray1<f32>, b: &PyArray1<f32>, cost: &PyArray2<f32>, reg: f32) -> PyResult<Py<PyArray2<f32>>> {
19 let transport_plan = impl_sinkhorn(a.to_owned_array(), b.to_owned_array(), cost.to_owned_array(), reg);
20 Ok(transport_plan.into_pyarray(py).to_owned())
21}
22
23#[pymodule]
24fn rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
25 m.add_wrapped(wrap_pyfunction!(sinkhorn))?;
26 m.add_wrapped(wrap_pyfunction!(calculate_1D_ot))?;
27
28 Ok(())
29}