1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
use numpy::{PyArray1, PyArray2};
use numpy::convert::IntoPyArray;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
mod sinkhorn;
mod lp_solver;

use sinkhorn::sinkhorn as impl_sinkhorn;
use lp_solver::calculate_1D_ot as impl_calculate_1D_ot;

#[pyfunction]
fn calculate_1D_ot(py: Python<'_>, a: &PyArray1<i32>, b: &PyArray1<i32>, cost: &PyArray2<i32>) -> PyResult<(i32, Py<PyArray2<u32>>)> {
   let (cost, transport_plan) = impl_calculate_1D_ot(&a.to_owned_array(), &b.to_owned_array(), &cost.to_owned_array());
   Ok((cost, transport_plan.into_pyarray(py).to_owned()))
}

#[pyfunction]
fn sinkhorn(py: Python<'_>, a: &PyArray1<f32>, b: &PyArray1<f32>, cost: &PyArray2<f32>, reg: f32) -> PyResult<Py<PyArray2<f32>>> {
   let transport_plan = impl_sinkhorn(a.to_owned_array(), b.to_owned_array(), cost.to_owned_array(), reg);
   Ok(transport_plan.into_pyarray(py).to_owned())
}

/// Initialiser for the Python extension module `rust`.
///
/// Exposes two module-level functions, registered in this order: `sinkhorn`,
/// then `calculate_1D_ot`. Any registration failure is propagated to the
/// Python import machinery.
#[pymodule]
fn rust(_python: Python<'_>, module: &PyModule) -> PyResult<()> {
    module.add_wrapped(wrap_pyfunction!(sinkhorn))?;
    // The last registration's result is the module initialiser's result.
    module.add_wrapped(wrap_pyfunction!(calculate_1D_ot))
}