mod py_tracing;
mod simplify_path;
use std::path::PathBuf;
use std::sync::mpsc;
use pyo3::exceptions::{PyException, PyIOError};
use pyo3::prelude::*;
use crate::error::{Error, Result};
use crate::routing::RouteDefinition;
use crate::{ArrayIndex, RevrtRoutingSolutions, Solution, resolve, resolve_generator};
// ---- Python-side request shapes -------------------------------------------------

/// A single `(i, j)` cell index as exchanged with Python.
type PyRoutePoint = (u64, u64);
/// All candidate cells for one end of a route (every start point, or every end point).
type PyPossibleRouteNodes = Vec<PyRoutePoint>;
/// One route request from Python: `(route_id, start_points, end_points)`.
type PyRouteDefinition = (u32, PyPossibleRouteNodes, PyPossibleRouteNodes);

// ---- Python-side result shapes --------------------------------------------------

/// One solved route: the `(i, j)` cells along the path, followed by its total cost.
type PyRouteResult = (Vec<PyRoutePoint>, f32);
/// The collection of solved routes returned for one request.
type PyRoutingSolutions = Vec<PyRouteResult>;
/// Return type of `RouteOutputIter.__next__`: `Some((route_id, solutions))` per item,
/// `None` once the iterator is exhausted.
type PyRouteYield = PyResult<Option<(u32, PyRoutingSolutions)>>;
impl From<&PyRouteDefinition> for RouteDefinition {
    /// Converts a Python `(route_id, start_points, end_points)` tuple into the solver's
    /// [`RouteDefinition`], turning every `(i, j)` pair into an [`ArrayIndex`].
    fn from((route_id, start_points, end_points): &PyRouteDefinition) -> Self {
        Self {
            route_id: *route_id,
            start_inds: start_points.iter().map(|&(i, j)| ArrayIndex { i, j }).collect(),
            end_inds: end_points.iter().map(|&(i, j)| ArrayIndex { i, j }).collect(),
        }
    }
}
impl From<Solution<ArrayIndex, f32>> for PyRouteResult {
    /// Flattens a solver [`Solution`] into the `(path, total_cost)` tuple handed to Python.
    /// Each [`ArrayIndex`] on the route is converted into a `(u64, u64)` point.
    fn from(Solution { route, total_cost }: Solution<ArrayIndex, f32>) -> Self {
        (route.into_iter().map(Into::into).collect(), total_cost)
    }
}
// Python exception class `revrtRustError`, subclassing Python's `Exception`.
// `From<Error> for PyErr` raises it for `Error::Undefined`, and the `_rust` module
// initializer registers it under the same name. The non-CamelCase identifier is
// kept on purpose because it is the name Python code sees.
pyo3::create_exception!(_rust, revrtRustError, PyException);
impl From<Error> for PyErr {
fn from(err: Error) -> PyErr {
match err {
Error::IO(msg) => PyIOError::new_err(msg),
Error::ZarrsArrayCreate(e) => PyIOError::new_err(e.to_string()),
Error::ZarrsArray(e) => PyIOError::new_err(e.to_string()),
Error::ZarrsStorage(e) => PyIOError::new_err(e.to_string()),
Error::ZarrsGroupCreate(e) => PyIOError::new_err(e.to_string()),
Error::Undefined(msg) => revrtRustError::new_err(msg),
}
}
}
/// Native extension module `_rust`.
///
/// Exposes the `find_paths` and `simplify_using_slopes` functions, the `RouteFinder`
/// class, and the `revrtRustError` exception to Python.
#[pymodule]
fn _rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    let functions = [
        wrap_pyfunction!(find_paths, m)?,
        wrap_pyfunction!(simplify_using_slopes, m)?,
    ];
    for function in functions {
        m.add_function(function)?;
    }
    m.add_class::<RouteFinder>()?;
    m.add("revrtRustError", py.get_type::<revrtRustError>())?;
    Ok(())
}
#[pyfunction]
#[pyo3(signature = (path, slope_tolerance=0.5))]
fn simplify_using_slopes(path: Vec<(f64, f64)>, slope_tolerance: f64) -> Vec<(f64, f64)> {
simplify_path::simplify_path(path, slope_tolerance)
}
#[pyfunction]
#[pyo3(signature = (zarr_fp, cost_function, start, end, cache_size=250_000_000))]
#[allow(clippy::type_complexity)]
fn find_paths(
zarr_fp: PathBuf,
cost_function: String,
start: Vec<(u64, u64)>,
end: Vec<(u64, u64)>,
cache_size: u64,
) -> Result<PyRoutingSolutions> {
let start: Vec<ArrayIndex> = start
.into_iter()
.map(|(i, j)| ArrayIndex { i, j })
.collect();
let end: Vec<ArrayIndex> = end.into_iter().map(|(i, j)| ArrayIndex { i, j }).collect();
let paths = resolve(zarr_fp, &cost_function, cache_size, &start, end)?;
Ok(paths.into_iter().map(Into::into).collect())
}
/// Python-facing description of a batch of routes to solve.
///
/// Construction only stores the inputs and configures logging. Each call to `iter()` on
/// the instance builds a new `RouteOutputIter`, which calls `resolve_generator` again.
#[pyclass]
struct RouteFinder {
    /// Path to the Zarr data passed to `resolve_generator`.
    zarr_fp: PathBuf,
    /// Cost-function specification forwarded verbatim to `resolve_generator`.
    cost_function: String,
    /// `(route_id, start_points, end_points)` tuples, one per route.
    route_definitions: Vec<PyRouteDefinition>,
    /// Cache size forwarded to `resolve_generator` (units not visible here).
    cache_size: u64,
}

#[pymethods]
impl RouteFinder {
    /// Stores the routing inputs after configuring logging.
    ///
    /// `log_level` is passed to `py_tracing::configure`. Any error it reports is raised
    /// to Python.
    #[new]
    #[pyo3(signature = (zarr_fp, cost_function, route_definitions, cache_size=250_000_000, log_level=None))]
    fn new(
        zarr_fp: PathBuf,
        cost_function: String,
        route_definitions: Vec<PyRouteDefinition>,
        cache_size: u64,
        log_level: Option<u8>,
    ) -> PyResult<Self> {
        py_tracing::configure(log_level)?;
        let finder = Self {
            zarr_fp,
            cost_function,
            route_definitions,
            cache_size,
        };
        Ok(finder)
    }

    /// Returns a new iterator yielding `(route_id, solutions)` pairs.
    fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<RouteOutputIter>> {
        let py = slf.py();
        Py::new(py, RouteOutputIter::new(&slf)?)
    }

    fn __str__(&self) -> PyResult<String> {
        let route_count = self.route_definitions.len();
        Ok(format!("`RouteFinder` instance for {route_count} routes"))
    }
}
/// Python iterator over `(route_id, solutions)` pairs sent by `resolve_generator`.
///
/// The object owns the receiving half of the channel whose sender was handed to
/// `resolve_generator`. Once the sending side hangs up and the channel is drained, the
/// receiver is dropped and iteration stops. `mpsc::Receiver` is not `Sync`, which is
/// presumably why the class is `unsendable` (pyo3 then pins it to the creating thread).
#[pyclass(unsendable)]
struct RouteOutputIter {
    /// `Some` while results may still arrive; `None` once the channel is disconnected.
    receiver: Option<mpsc::Receiver<(u32, RevrtRoutingSolutions)>>,
}

impl RouteOutputIter {
    /// Longest single blocking wait before the GIL is re-acquired to check for pending
    /// Python signals such as Ctrl+C.
    const SIGNAL_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(100);

    /// Builds an iterator over the results of `user_input`'s route definitions.
    ///
    /// Passes the converted definitions and the sending half of a fresh channel to
    /// `resolve_generator`, and keeps the receiving half.
    ///
    /// # Errors
    /// Propagates any error returned by `resolve_generator`.
    fn new(user_input: &RouteFinder) -> Result<Self> {
        let (tx, rx) = mpsc::channel();
        resolve_generator(
            &user_input.zarr_fp,
            &user_input.cost_function,
            user_input
                .route_definitions
                .iter()
                .map(Into::into)
                .collect::<Vec<_>>(),
            tx,
            user_input.cache_size,
        )?;
        Ok(Self { receiver: Some(rx) })
    }
}

#[pymethods]
impl RouteOutputIter {
    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    /// Blocks until the next `(route_id, solutions)` pair arrives.
    ///
    /// Returns `Ok(None)` (`StopIteration`) once the producer has hung up and every
    /// message has been delivered.
    ///
    /// The GIL is released while waiting. The wait is split into short slices so that
    /// pending Python signals are handled. If a signal handler raises (for example
    /// `KeyboardInterrupt`), that error propagates and the iterator stays usable.
    fn __next__(mut slf: PyRefMut<'_, Self>) -> PyRouteYield {
        let Some(mut receiver) = slf.receiver.take() else {
            return Ok(None);
        };
        let py = slf.py();
        loop {
            // `Receiver` is not `Sync`, so it cannot be borrowed into the detached closure.
            // It is moved in and handed back out instead.
            let (recv_result, returned) = py.detach(move || {
                let result = receiver.recv_timeout(Self::SIGNAL_POLL_INTERVAL);
                (result, receiver)
            });
            receiver = returned;
            match recv_result {
                Ok((id, solutions)) => {
                    slf.receiver = Some(receiver);
                    return Ok(Some((id, solutions.into_iter().map(Into::into).collect())));
                }
                Err(mpsc::RecvTimeoutError::Timeout) => {
                    if let Err(err) = py.check_signals() {
                        // Keep the channel so the caller can resume after handling
                        // the interrupt.
                        slf.receiver = Some(receiver);
                        return Err(err);
                    }
                }
                // All senders are gone and the channel is drained. Leaving `receiver`
                // as `None` makes every later call stop immediately.
                Err(mpsc::RecvTimeoutError::Disconnected) => return Ok(None),
            }
        }
    }
}