kiss_icp_pybind/
lib.rs

// Binding functions such as `_Vector3dVector` keep their non-snake-case names because
// PyO3 exposes each `#[pyfunction]` to Python under its Rust identifier.
#![allow(non_snake_case)]

use kiss_icp_core::{
    deskew, metrics, preprocessing, runtime,
    threshold::AdaptiveThreshold,
    types::{IntoIsometry3, IntoVoxelPoint, PyIsometry3, PyVoxelPoint},
    voxel_hash_map::{VoxelHashMap, VoxelHashMapArgs},
};
use numpy::{
    nalgebra::{Dyn, MatrixXx3, U3},
    PyArray2, ToPyArray,
};
use pyo3::{
    exceptions::{PyException, PyValueError},
    pyclass, pyfunction, pymethods, pymodule,
    types::PyModule,
    wrap_pyfunction, PyObject, PyResult, Python,
};
use rayon::iter::ParallelIterator;
18
19type PyListVoxelPoint<'py> = ::numpy::PyReadonlyArray2<'py, f64>;
20
21#[pyfunction]
22fn _Vector3dVector(vec: PyObject) -> PyObject {
23    vec
24}
25
26/// Map representation
27#[pyclass]
28struct _VoxelHashMap(VoxelHashMap);
29
30#[pymethods]
31impl _VoxelHashMap {
32    #[new]
33    fn new(voxel_size: f64, max_distance: f64, max_points_per_voxel: usize) -> Self {
34        let args = VoxelHashMapArgs {
35            max_distance2: max_distance * max_distance,
36            max_points_per_voxel,
37        };
38
39        Self(VoxelHashMap::new(args, voxel_size))
40    }
41
42    fn _clear(&mut self) {
43        self.0.clear()
44    }
45
46    fn _empty(&self) -> bool {
47        self.0.is_empty()
48    }
49
50    fn _update(&mut self, points: PyListVoxelPoint, pose: PyIsometry3) {
51        self._update_with_pose(points, pose)
52    }
53
54    #[inline]
55    fn _update_with_origin(&mut self, points: PyListVoxelPoint, origin: PyVoxelPoint) {
56        let points = points.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
57        self.0.update_with_origin(&points, origin)
58    }
59
60    #[inline]
61    fn _update_with_pose(&mut self, points: PyListVoxelPoint, pose: PyIsometry3) {
62        let points = points.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
63        self.0.update_with_pose(&points, pose)
64    }
65
66    fn _add_points(&mut self, points: PyListVoxelPoint) {
67        let points = points.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
68        self.0.add_points(&points)
69    }
70
71    fn _remove_far_away_points(&mut self, origin: PyVoxelPoint) {
72        self.0.remove_points_far_from_location(origin)
73    }
74
75    fn _point_cloud<'py>(&self, py: Python<'py>) -> &'py PyArray2<f64> {
76        let points: Vec<_> = self
77            .0
78            .get_point_cloud()
79            .map(|point| point.transpose())
80            .collect();
81        MatrixXx3::from_rows(&points).to_pyarray(py)
82    }
83
84    fn _get_correspondences<'py>(
85        &self,
86        py: Python<'py>,
87        points: PyListVoxelPoint,
88        max_correspondance_distance: f64,
89    ) -> (&'py PyArray2<f64>, &'py PyArray2<f64>) {
90        let points = points.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
91        let (points, closest_neighbors): (Vec<_>, Vec<_>) = self
92            .0
93            .get_correspondences(&points, max_correspondance_distance)
94            .map(|(point, closest_neighbor)| (point.transpose(), closest_neighbor.transpose()))
95            .unzip();
96        (
97            MatrixXx3::from_rows(&points).to_pyarray(py),
98            MatrixXx3::from_rows(&closest_neighbors).to_pyarray(py),
99        )
100    }
101}
102
103/// Point Cloud registration
104#[pyfunction]
105fn _register_point_cloud(
106    points: PyListVoxelPoint,
107    voxel_map: &_VoxelHashMap,
108    initial_guess: PyIsometry3,
109    max_correspondance_distance: f64,
110    kernel: f64,
111) -> PyIsometry3 {
112    let points = points
113        .try_as_matrix::<Dyn, U3, Dyn, Dyn>()
114        .unwrap()
115        .clone_owned();
116    voxel_map
117        .0
118        .register_frame(points, initial_guess, max_correspondance_distance, kernel)
119        .into_py_isometry3()
120}
121
122/// AdaptiveThreshold bindings
123#[pyclass]
124struct _AdaptiveThreshold(AdaptiveThreshold);
125
126#[pymethods]
127impl _AdaptiveThreshold {
128    #[new]
129    fn new(initial_threshold: f64, min_motion_th: f64, max_range: f64) -> Self {
130        Self(AdaptiveThreshold::new(
131            initial_threshold,
132            min_motion_th,
133            max_range,
134        ))
135    }
136
137    fn _compute_threshold(&mut self) -> f64 {
138        self.0.compute_threshold()
139    }
140
141    fn _update_model_deviation(&mut self, model_deviation: PyIsometry3) {
142        self.0.update_model_deviation(model_deviation)
143    }
144}
145
146/// DeSkewScan
147#[pyfunction]
148fn _deskew_scan(
149    frame: Vec<PyVoxelPoint>,
150    timestamps: Vec<f64>,
151    start_pose: PyIsometry3,
152    finish_pose: PyIsometry3,
153) -> Vec<PyVoxelPoint> {
154    deskew::scan(&frame, &timestamps, start_pose, finish_pose)
155        .map(IntoVoxelPoint::into_py_voxel_point)
156        .collect()
157}
158
159// preprocessing modules
160
161#[pyfunction]
162fn _voxel_down_sample<'py>(
163    py: Python<'py>,
164    frame: PyListVoxelPoint<'py>,
165    voxel_size: f64,
166) -> &'py PyArray2<f64> {
167    let frame = frame.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
168    let downsampled: Vec<_> = preprocessing::voxel_downsample(&frame, voxel_size)
169        .map(|point| point.transpose())
170        .collect();
171    MatrixXx3::from_rows(&downsampled).to_pyarray(py)
172}
173
174#[pyfunction]
175fn _preprocess<'py>(
176    py: Python<'py>,
177    frame: PyListVoxelPoint<'py>,
178    max_range: f64,
179    min_range: f64,
180) -> &'py PyArray2<f64> {
181    let frame = frame.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
182    let preprocessed = preprocessing::preprocess(&frame, min_range..max_range)
183        .map(|point| point.transpose())
184        .collect::<Vec<_>>();
185    MatrixXx3::from_rows(&preprocessed).to_pyarray(py)
186}
187
188#[pyfunction]
189fn _correct_kitti_scan<'py>(py: Python<'py>, frame: PyListVoxelPoint<'py>) -> &'py PyArray2<f64> {
190    let frame = frame.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
191    let corrected: Vec<_> = preprocessing::correct_kitti_scan(&frame)
192        .map(|point| point.transpose())
193        .collect();
194    MatrixXx3::from_rows(&corrected).to_pyarray(py)
195}
196
197// Metrics
198
199#[pyfunction]
200fn _kitti_seq_error(gt_poses: Vec<PyIsometry3>, results_poses: Vec<PyIsometry3>) -> (f64, f64) {
201    metrics::seq_error(&gt_poses, &results_poses)
202}
203
204#[pyfunction]
205fn _absolute_trajectory_error(
206    gt_poses: Vec<PyIsometry3>,
207    results_poses: Vec<PyIsometry3>,
208) -> (f64, f64) {
209    metrics::absolute_trajectory_error(&gt_poses, &results_poses)
210}
211
212#[pymodule]
213fn kiss_icp_pybind(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
214    // optimize performance
215    runtime::init(runtime::SystemType::Library).map_err(|error| {
216        PyException::new_err(format!(
217            "failed to init {name}: {error}",
218            name = env!("CARGO_CRATE_NAME"),
219        ))
220    })?;
221
222    m.add_function(wrap_pyfunction!(_Vector3dVector, m)?)?;
223
224    m.add_class::<_VoxelHashMap>()?;
225
226    // Point Cloud registration
227    m.add_function(wrap_pyfunction!(_register_point_cloud, m)?)?;
228
229    // AdaptiveThreshold bindings
230    m.add_class::<_AdaptiveThreshold>()?;
231
232    // DeSkewScan
233    m.add_function(wrap_pyfunction!(_deskew_scan, m)?)?;
234
235    // preprocessing modules
236    m.add_function(wrap_pyfunction!(_voxel_down_sample, m)?)?;
237    m.add_function(wrap_pyfunction!(_preprocess, m)?)?;
238    m.add_function(wrap_pyfunction!(_correct_kitti_scan, m)?)?;
239
240    // Metrics
241    m.add_function(wrap_pyfunction!(_kitti_seq_error, m)?)?;
242    m.add_function(wrap_pyfunction!(_absolute_trajectory_error, m)?)?;
243
244    Ok(())
245}