1use super::multi_jagged::split_at_mut_many;
20use crate::geometry::{Mbr, PointND};
21
22use nalgebra::allocator::Allocator;
23use nalgebra::ArrayStorage;
24use nalgebra::Const;
25use nalgebra::DefaultAllocator;
26use nalgebra::DimDiff;
27use nalgebra::DimSub;
28use nalgebra::ToTypenum;
29use rayon::prelude::*;
30
31use std::cmp::Ordering;
32use std::sync::atomic::{self, AtomicPtr};
33
/// Integer type holding Z-curve hashes: each subdivision level contributes `D` bits.
type HashType = u128;
/// Largest representable hash value.
const HASH_TYPE_MAX: HashType = HashType::MAX;
39
40fn z_curve_partition<const D: usize>(
41 partition: &mut [usize],
42 points: &[PointND<D>],
43 part_count: usize,
44 order: u32,
45) where
46 Const<D>: DimSub<Const<1>> + ToTypenum,
47 DefaultAllocator: Allocator<f64, Const<D>, Const<D>, Buffer = ArrayStorage<f64, D, D>>
48 + Allocator<f64, DimDiff<Const<D>, Const<1>>>,
49{
50 let max_order = (HASH_TYPE_MAX as f64).log(f64::from(2u32.pow(D as u32))) as u32;
51 assert!(
52 order <= max_order,
53 "Cannot use the z-curve partition algorithm with an order > {} because it would currently overflow hashes capacity",
54 max_order,
55 );
56
57 let mbr = Mbr::from_points(points);
59
60 let mut permutation: Vec<_> = (0..points.len()).into_par_iter().collect();
61
62 z_curve_partition_recurse(points, order, &mbr, &mut permutation);
64
65 let points_per_partition = points.len() / part_count;
66 let remainder = points.len() % part_count;
67
68 let atomic_handle = AtomicPtr::from(partition.as_mut_ptr());
69
70 let threshold_idx = (points_per_partition + 1) * remainder;
77 permutation[..threshold_idx]
78 .par_chunks(points_per_partition + 1)
79 .chain(permutation[threshold_idx..].par_chunks(points_per_partition))
80 .enumerate()
81 .for_each(|(id, chunk)| {
82 let ptr = atomic_handle.load(atomic::Ordering::Relaxed);
83 for idx in chunk {
84 unsafe { std::ptr::write(ptr.add(*idx), id) }
85 }
86 });
87}
88
89fn z_curve_partition_recurse<const D: usize>(
91 points: &[PointND<D>],
92 order: u32,
93 mbr: &Mbr<D>,
94 permu: &mut [usize],
95) {
96 if order == 0 || permu.len() <= 1 {
98 return;
99 }
100
101 let regions = points
104 .par_iter()
105 .map(|p| mbr.region(p).unwrap_or(0))
106 .collect::<Vec<_>>();
107
108 permu.par_sort_unstable_by_key(|idx| regions[*idx] as u8);
110
111 let mut split_positions = (1..2usize.pow(D as u32)).collect::<Vec<_>>();
117 for n in split_positions.iter_mut() {
118 *n = permu
119 .binary_search_by(|idx| {
120 if (regions[*idx] as u8) < *n as u8 {
121 Ordering::Less
122 } else {
123 Ordering::Greater
124 }
125 })
126 .unwrap_err();
127 }
128
129 let slices = split_at_mut_many(permu, &split_positions);
130 slices.into_par_iter().enumerate().for_each(|(i, slice)| {
131 z_curve_partition_recurse(points, order - 1, &mbr.sub_mbr(i as u32), slice);
132 })
133}
134
135#[allow(unused)]
137pub(crate) fn z_curve_reorder<const D: usize>(points: &[PointND<D>], order: u32) -> Vec<usize>
138where
139 Const<D>: DimSub<Const<1>>,
140 DefaultAllocator: Allocator<f64, Const<D>, Const<D>, Buffer = ArrayStorage<f64, D, D>>
141 + Allocator<f64, DimDiff<Const<D>, Const<1>>>,
142{
143 let max_order = (HASH_TYPE_MAX as f64).log(f64::from(2u32.pow(D as u32))) as u32;
144 assert!(
145 order <= max_order,
146 "Cannot use the z-curve partition algorithm with an order > {} because it would currently overflow hashes capacity",
147 max_order,
148 );
149
150 let mut permu: Vec<_> = (0..points.len()).into_par_iter().collect();
151 z_curve_reorder_permu(points, permu.as_mut_slice(), order);
152 permu
153}
154
155#[allow(unused)]
158pub(crate) fn z_curve_reorder_permu<const D: usize>(
159 points: &[PointND<D>],
160 permu: &mut [usize],
161 order: u32,
162) where
163 Const<D>: DimSub<Const<1>>,
164 DefaultAllocator: Allocator<f64, Const<D>, Const<D>, Buffer = ArrayStorage<f64, D, D>>
165 + Allocator<f64, DimDiff<Const<D>, Const<1>>>,
166{
167 let mbr = Mbr::from_points(points);
168 let hashes = permu
169 .par_iter()
170 .map(|idx| compute_hash(&points[*idx], order, &mbr))
171 .collect::<Vec<_>>();
172
173 permu.par_sort_by_key(|idx| hashes[*idx]);
174}
175
176fn compute_hash<const D: usize>(point: &PointND<D>, order: u32, mbr: &Mbr<D>) -> HashType {
177 let current_hash = mbr
178 .region(point)
179 .expect("Cannot compute the z-hash of a point outside of the current Mbr.");
180
181 if order == 0 {
182 HashType::from(current_hash)
183 } else {
184 (2_u128.pow(D as u32)).pow(order) * HashType::from(current_hash)
185 + compute_hash(point, order - 1, &mbr.sub_mbr(current_hash))
186 }
187}
188
/// Z space-filling curve partitioning algorithm.
///
/// Points are ordered along a Z-curve and cut into `part_count` runs of
/// (almost) equal size.
#[derive(Debug, Clone, Copy)]
pub struct ZCurve {
    /// Number of parts to produce.
    pub part_count: usize,
    /// Number of recursive subdivision levels of the bounding box.
    pub order: u32,
}
226
227impl<'a, const D: usize> crate::Partition<&'a [PointND<D>]> for ZCurve
228where
229 Const<D>: DimSub<Const<1>> + ToTypenum,
230 DefaultAllocator: Allocator<f64, Const<D>, Const<D>, Buffer = ArrayStorage<f64, D, D>>
231 + Allocator<f64, DimDiff<Const<D>, Const<1>>>,
232{
233 type Metadata = ();
234 type Error = std::convert::Infallible;
235
236 fn partition(
237 &mut self,
238 part_ids: &mut [usize],
239 points: &'a [PointND<D>],
240 ) -> Result<Self::Metadata, Self::Error> {
241 z_curve_partition(part_ids, points, self.part_count, self.order);
242 Ok(())
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::geometry::Point2D;

    #[test]
    fn test_partition() {
        let points = [
            Point2D::from([0., 0.]),
            Point2D::from([20., 10.]),
            Point2D::from([0., 10.]),
            Point2D::from([20., 0.]),
            Point2D::from([14., 7.]),
            Point2D::from([4., 7.]),
            Point2D::from([14., 2.]),
            Point2D::from([4., 2.]),
        ];

        let mut ids = [0; 8];
        z_curve_partition(&mut ids, &points, 4, 1);

        // The points are expected to fall pairwise into the same first-level
        // sub-region, so each pair must end up in the same part.
        assert_eq!(ids[0], ids[7]);
        assert_eq!(ids[1], ids[4]);
        assert_eq!(ids[2], ids[5]);
        assert_eq!(ids[3], ids[6]);

        // Eight points over four parts: every part id is used exactly twice.
        for part in 0..4 {
            assert_eq!(ids.iter().filter(|&&id| id == part).count(), 2);
        }
    }
}