1#![allow(non_snake_case)]
2
3use kiss_icp_core::{
4 deskew, metrics, preprocessing, runtime,
5 threshold::AdaptiveThreshold,
6 types::{IntoIsometry3, IntoVoxelPoint, PyIsometry3, PyVoxelPoint},
7 voxel_hash_map::{VoxelHashMap, VoxelHashMapArgs},
8};
9use numpy::{
10 nalgebra::{Dyn, MatrixXx3, U3},
11 PyArray2, ToPyArray,
12};
use pyo3::{
    exceptions::{PyException, PyValueError},
    pyclass, pyfunction, pymethods, pymodule,
    types::PyModule,
    wrap_pyfunction, PyObject, PyResult, Python,
};
17use rayon::iter::ParallelIterator;
18
/// Borrowed, read-only 2-D `f64` NumPy array; the bindings in this file view it as an
/// `(N, 3)` matrix via `try_as_matrix::<Dyn, U3, Dyn, Dyn>()`.
type PyListVoxelPoint<'py> = ::numpy::PyReadonlyArray2<'py, f64>;
20
21#[pyfunction]
22fn _Vector3dVector(vec: PyObject) -> PyObject {
23 vec
24}
25
26#[pyclass]
28struct _VoxelHashMap(VoxelHashMap);
29
30#[pymethods]
31impl _VoxelHashMap {
32 #[new]
33 fn new(voxel_size: f64, max_distance: f64, max_points_per_voxel: usize) -> Self {
34 let args = VoxelHashMapArgs {
35 max_distance2: max_distance * max_distance,
36 max_points_per_voxel,
37 };
38
39 Self(VoxelHashMap::new(args, voxel_size))
40 }
41
42 fn _clear(&mut self) {
43 self.0.clear()
44 }
45
46 fn _empty(&self) -> bool {
47 self.0.is_empty()
48 }
49
50 fn _update(&mut self, points: PyListVoxelPoint, pose: PyIsometry3) {
51 self._update_with_pose(points, pose)
52 }
53
54 #[inline]
55 fn _update_with_origin(&mut self, points: PyListVoxelPoint, origin: PyVoxelPoint) {
56 let points = points.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
57 self.0.update_with_origin(&points, origin)
58 }
59
60 #[inline]
61 fn _update_with_pose(&mut self, points: PyListVoxelPoint, pose: PyIsometry3) {
62 let points = points.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
63 self.0.update_with_pose(&points, pose)
64 }
65
66 fn _add_points(&mut self, points: PyListVoxelPoint) {
67 let points = points.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
68 self.0.add_points(&points)
69 }
70
71 fn _remove_far_away_points(&mut self, origin: PyVoxelPoint) {
72 self.0.remove_points_far_from_location(origin)
73 }
74
75 fn _point_cloud<'py>(&self, py: Python<'py>) -> &'py PyArray2<f64> {
76 let points: Vec<_> = self
77 .0
78 .get_point_cloud()
79 .map(|point| point.transpose())
80 .collect();
81 MatrixXx3::from_rows(&points).to_pyarray(py)
82 }
83
84 fn _get_correspondences<'py>(
85 &self,
86 py: Python<'py>,
87 points: PyListVoxelPoint,
88 max_correspondance_distance: f64,
89 ) -> (&'py PyArray2<f64>, &'py PyArray2<f64>) {
90 let points = points.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
91 let (points, closest_neighbors): (Vec<_>, Vec<_>) = self
92 .0
93 .get_correspondences(&points, max_correspondance_distance)
94 .map(|(point, closest_neighbor)| (point.transpose(), closest_neighbor.transpose()))
95 .unzip();
96 (
97 MatrixXx3::from_rows(&points).to_pyarray(py),
98 MatrixXx3::from_rows(&closest_neighbors).to_pyarray(py),
99 )
100 }
101}
102
103#[pyfunction]
105fn _register_point_cloud(
106 points: PyListVoxelPoint,
107 voxel_map: &_VoxelHashMap,
108 initial_guess: PyIsometry3,
109 max_correspondance_distance: f64,
110 kernel: f64,
111) -> PyIsometry3 {
112 let points = points
113 .try_as_matrix::<Dyn, U3, Dyn, Dyn>()
114 .unwrap()
115 .clone_owned();
116 voxel_map
117 .0
118 .register_frame(points, initial_guess, max_correspondance_distance, kernel)
119 .into_py_isometry3()
120}
121
122#[pyclass]
124struct _AdaptiveThreshold(AdaptiveThreshold);
125
126#[pymethods]
127impl _AdaptiveThreshold {
128 #[new]
129 fn new(initial_threshold: f64, min_motion_th: f64, max_range: f64) -> Self {
130 Self(AdaptiveThreshold::new(
131 initial_threshold,
132 min_motion_th,
133 max_range,
134 ))
135 }
136
137 fn _compute_threshold(&mut self) -> f64 {
138 self.0.compute_threshold()
139 }
140
141 fn _update_model_deviation(&mut self, model_deviation: PyIsometry3) {
142 self.0.update_model_deviation(model_deviation)
143 }
144}
145
146#[pyfunction]
148fn _deskew_scan(
149 frame: Vec<PyVoxelPoint>,
150 timestamps: Vec<f64>,
151 start_pose: PyIsometry3,
152 finish_pose: PyIsometry3,
153) -> Vec<PyVoxelPoint> {
154 deskew::scan(&frame, ×tamps, start_pose, finish_pose)
155 .map(IntoVoxelPoint::into_py_voxel_point)
156 .collect()
157}
158
159#[pyfunction]
162fn _voxel_down_sample<'py>(
163 py: Python<'py>,
164 frame: PyListVoxelPoint<'py>,
165 voxel_size: f64,
166) -> &'py PyArray2<f64> {
167 let frame = frame.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
168 let downsampled: Vec<_> = preprocessing::voxel_downsample(&frame, voxel_size)
169 .map(|point| point.transpose())
170 .collect();
171 MatrixXx3::from_rows(&downsampled).to_pyarray(py)
172}
173
174#[pyfunction]
175fn _preprocess<'py>(
176 py: Python<'py>,
177 frame: PyListVoxelPoint<'py>,
178 max_range: f64,
179 min_range: f64,
180) -> &'py PyArray2<f64> {
181 let frame = frame.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
182 let preprocessed = preprocessing::preprocess(&frame, min_range..max_range)
183 .map(|point| point.transpose())
184 .collect::<Vec<_>>();
185 MatrixXx3::from_rows(&preprocessed).to_pyarray(py)
186}
187
188#[pyfunction]
189fn _correct_kitti_scan<'py>(py: Python<'py>, frame: PyListVoxelPoint<'py>) -> &'py PyArray2<f64> {
190 let frame = frame.try_as_matrix::<Dyn, U3, Dyn, Dyn>().unwrap();
191 let corrected: Vec<_> = preprocessing::correct_kitti_scan(&frame)
192 .map(|point| point.transpose())
193 .collect();
194 MatrixXx3::from_rows(&corrected).to_pyarray(py)
195}
196
197#[pyfunction]
200fn _kitti_seq_error(gt_poses: Vec<PyIsometry3>, results_poses: Vec<PyIsometry3>) -> (f64, f64) {
201 metrics::seq_error(>_poses, &results_poses)
202}
203
204#[pyfunction]
205fn _absolute_trajectory_error(
206 gt_poses: Vec<PyIsometry3>,
207 results_poses: Vec<PyIsometry3>,
208) -> (f64, f64) {
209 metrics::absolute_trajectory_error(>_poses, &results_poses)
210}
211
212#[pymodule]
213fn kiss_icp_pybind(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
214 runtime::init(runtime::SystemType::Library).map_err(|error| {
216 PyException::new_err(format!(
217 "failed to init {name}: {error}",
218 name = env!("CARGO_CRATE_NAME"),
219 ))
220 })?;
221
222 m.add_function(wrap_pyfunction!(_Vector3dVector, m)?)?;
223
224 m.add_class::<_VoxelHashMap>()?;
225
226 m.add_function(wrap_pyfunction!(_register_point_cloud, m)?)?;
228
229 m.add_class::<_AdaptiveThreshold>()?;
231
232 m.add_function(wrap_pyfunction!(_deskew_scan, m)?)?;
234
235 m.add_function(wrap_pyfunction!(_voxel_down_sample, m)?)?;
237 m.add_function(wrap_pyfunction!(_preprocess, m)?)?;
238 m.add_function(wrap_pyfunction!(_correct_kitti_scan, m)?)?;
239
240 m.add_function(wrap_pyfunction!(_kitti_seq_error, m)?)?;
242 m.add_function(wrap_pyfunction!(_absolute_trajectory_error, m)?)?;
243
244 Ok(())
245}