1use std::{ops::RangeBounds, ptr::NonNull};
2
3use crate::{SharedSubgrid, SimdVector};
4
/// A mutable, possibly strided, rectangular view into a buffer of `V` (defaults to `f32`).
///
/// Element `(x, y)` is located `y * stride + x` elements after `ptr`. The view behaves like an
/// exclusive borrow of the underlying elements for the lifetime `'g`.
#[derive(Debug)]
pub struct MutableSubgrid<'g, V = f32> {
    /// Pointer to element `(0, 0)`.
    ptr: NonNull<V>,
    /// Base pointer of the grid this view was split from, or `None` if it was never split.
    /// Views can only be merged back together when their split bases are equal.
    split_base: Option<NonNull<()>>,
    /// Number of columns in the view.
    width: usize,
    /// Number of rows in the view.
    height: usize,
    /// Distance, in elements of `V`, between the starts of consecutive rows.
    stride: usize,
    /// Ties the view to an exclusive borrow of `[V]` for `'g` (lifetime and variance).
    _marker: std::marker::PhantomData<&'g mut [V]>,
}
15
16unsafe impl<'g, V> Send for MutableSubgrid<'g, V> where &'g mut [V]: Send {}
18unsafe impl<'g, V> Sync for MutableSubgrid<'g, V> where &'g mut [V]: Sync {}
19
20impl<'g, V> From<&'g mut crate::AlignedGrid<V>> for MutableSubgrid<'g, V> {
21 fn from(grid: &'g mut crate::AlignedGrid<V>) -> Self {
22 let width = grid.width();
23 let height = grid.height();
24 Self::from_buf(grid.buf_mut(), width, height, width)
25 }
26}
27
28impl<'g, V> MutableSubgrid<'g, V> {
29 pub unsafe fn new(ptr: NonNull<V>, width: usize, height: usize, stride: usize) -> Self {
38 assert!(width == 0 || width <= stride);
39 Self {
40 ptr,
41 split_base: None,
42 width,
43 height,
44 stride,
45 _marker: Default::default(),
46 }
47 }
48
49 pub fn empty() -> Self {
51 Self {
52 ptr: NonNull::dangling(),
53 split_base: None,
54 width: 0,
55 height: 0,
56 stride: 0,
57 _marker: Default::default(),
58 }
59 }
60
61 pub fn from_buf(buf: &'g mut [V], width: usize, height: usize, stride: usize) -> Self {
68 assert!(width <= stride);
69 if width == 0 || height == 0 {
70 assert_eq!(buf.len(), 0);
71 } else {
72 assert!(buf.len() >= stride * (height - 1) + width);
73 }
74 unsafe {
76 Self::new(
77 NonNull::new(buf.as_mut_ptr()).unwrap(),
78 width,
79 height,
80 stride,
81 )
82 }
83 }
84
85 #[inline]
87 pub fn width(&self) -> usize {
88 self.width
89 }
90
91 #[inline]
93 pub fn height(&self) -> usize {
94 self.height
95 }
96
97 #[inline]
98 fn get_ptr(&self, x: usize, y: usize) -> *mut V {
99 let width = self.width;
100 let height = self.height;
101 let Some(ptr) = self.try_get_ptr(x, y) else {
102 panic!("coordinate out of range: ({x}, {y}) not in {width}x{height}");
103 };
104
105 ptr
106 }
107
108 #[inline]
109 fn try_get_ptr(&self, x: usize, y: usize) -> Option<*mut V> {
110 if x >= self.width || y >= self.height {
111 return None;
112 }
113
114 Some(unsafe { self.get_ptr_unchecked(x, y) })
116 }
117
118 #[inline]
119 unsafe fn get_ptr_unchecked(&self, x: usize, y: usize) -> *mut V {
120 let offset = y * self.stride + x;
121 unsafe { self.ptr.as_ptr().add(offset) }
122 }
123
124 #[inline]
129 pub fn get_ref(&self, x: usize, y: usize) -> &V {
130 let width = self.width;
131 let height = self.height;
132 let Some(r) = self.try_get_ref(x, y) else {
133 panic!("coordinate out of range: ({x}, {y}) not in {width}x{height}");
134 };
135
136 r
137 }
138
139 #[inline]
141 pub fn try_get_ref(&self, x: usize, y: usize) -> Option<&V> {
142 self.try_get_ptr(x, y).map(|ptr| unsafe { &*ptr })
144 }
145
146 #[inline]
151 pub fn get_row(&self, row: usize) -> &[V] {
152 let height = self.height;
153 let Some(slice) = self.try_get_row(row) else {
154 panic!("row index out of range: height is {height} but index is {row}");
155 };
156
157 slice
158 }
159
160 #[inline]
162 pub fn try_get_row(&self, row: usize) -> Option<&[V]> {
163 if row >= self.height {
164 return None;
165 }
166
167 Some(unsafe {
169 let offset = row * self.stride;
170 let ptr = self.ptr.as_ptr().add(offset);
171 std::slice::from_raw_parts(ptr as *const _, self.width)
172 })
173 }
174
175 #[inline]
180 pub fn get_mut(&mut self, x: usize, y: usize) -> &mut V {
181 let width = self.width;
182 let height = self.height;
183 let Some(r) = self.try_get_mut(x, y) else {
184 panic!("coordinate out of range: ({x}, {y}) not in {width}x{height}");
185 };
186
187 r
188 }
189
190 #[inline]
193 pub fn try_get_mut(&mut self, x: usize, y: usize) -> Option<&mut V> {
194 self.try_get_ptr(x, y).map(|ptr| unsafe { &mut *ptr })
197 }
198
199 #[inline]
204 pub fn get_row_mut(&mut self, row: usize) -> &mut [V] {
205 let height = self.height;
206 let Some(slice) = self.try_get_row_mut(row) else {
207 panic!("row index out of range: height is {height} but index is {row}");
208 };
209
210 slice
211 }
212
213 #[inline]
215 pub fn try_get_row_mut(&mut self, row: usize) -> Option<&mut [V]> {
216 if row >= self.height {
217 return None;
218 }
219
220 Some(unsafe {
222 let offset = row * self.stride;
223 let ptr = self.ptr.as_ptr().add(offset);
224 std::slice::from_raw_parts_mut(ptr, self.width)
225 })
226 }
227
228 #[inline]
233 pub fn swap(&mut self, (ax, ay): (usize, usize), (bx, by): (usize, usize)) {
234 let a = self.get_ptr(ax, ay);
235 let b = self.get_ptr(bx, by);
236 if std::ptr::eq(a, b) {
237 return;
238 }
239
240 unsafe {
242 std::ptr::swap(a, b);
243 }
244 }
245
246 pub fn borrow_mut(&mut self) -> MutableSubgrid<V> {
248 unsafe { MutableSubgrid::new(self.ptr, self.width, self.height, self.stride) }
250 }
251
252 pub fn as_shared(&self) -> SharedSubgrid<V> {
254 unsafe { SharedSubgrid::new(self.ptr, self.width, self.height, self.stride) }
256 }
257
258 pub fn subgrid(
263 self,
264 range_x: impl RangeBounds<usize>,
265 range_y: impl RangeBounds<usize>,
266 ) -> MutableSubgrid<'g, V> {
267 use std::ops::Bound;
268
269 let left = match range_x.start_bound() {
270 Bound::Included(&v) => v,
271 Bound::Excluded(&v) => v + 1,
272 Bound::Unbounded => 0,
273 };
274 let right = match range_x.end_bound() {
275 Bound::Included(&v) => v + 1,
276 Bound::Excluded(&v) => v,
277 Bound::Unbounded => self.width,
278 };
279 let top = match range_y.start_bound() {
280 Bound::Included(&v) => v,
281 Bound::Excluded(&v) => v + 1,
282 Bound::Unbounded => 0,
283 };
284 let bottom = match range_y.end_bound() {
285 Bound::Included(&v) => v + 1,
286 Bound::Excluded(&v) => v,
287 Bound::Unbounded => self.height,
288 };
289
290 assert!(left <= right);
292 assert!(top <= bottom);
293 assert!(right <= self.width);
294 assert!(bottom <= self.height);
295
296 unsafe {
298 let base_ptr = NonNull::new(self.get_ptr_unchecked(left, top)).unwrap();
299 MutableSubgrid::new(base_ptr, right - left, bottom - top, self.stride)
300 }
301 }
302
303 pub fn split_horizontal(&mut self, x: usize) -> (MutableSubgrid<'_, V>, MutableSubgrid<'_, V>) {
308 assert!(x <= self.width);
309
310 let left_ptr = self.ptr;
311 let right_ptr = NonNull::new(unsafe { self.get_ptr_unchecked(x, 0) }).unwrap();
312 unsafe {
314 let split_base = self.split_base.unwrap_or(self.ptr.cast());
315 let mut left_grid = MutableSubgrid::new(left_ptr, x, self.height, self.stride);
316 let mut right_grid =
317 MutableSubgrid::new(right_ptr, self.width - x, self.height, self.stride);
318 left_grid.split_base = Some(split_base);
319 right_grid.split_base = Some(split_base);
320 (left_grid, right_grid)
321 }
322 }
323
324 pub fn split_horizontal_in_place(&mut self, x: usize) -> MutableSubgrid<'g, V> {
329 assert!(x <= self.width);
330
331 let right_width = self.width - x;
332 let right_ptr = NonNull::new(unsafe { self.get_ptr_unchecked(x, 0) }).unwrap();
333 unsafe {
335 let split_base = self.split_base.unwrap_or(self.ptr.cast());
336 self.width = x;
337 self.split_base = Some(split_base);
338 let mut right_grid =
339 MutableSubgrid::new(right_ptr, right_width, self.height, self.stride);
340 right_grid.split_base = Some(split_base);
341 right_grid
342 }
343 }
344
345 pub fn split_vertical(&mut self, y: usize) -> (MutableSubgrid<'_, V>, MutableSubgrid<'_, V>) {
350 assert!(y <= self.height);
351
352 let top_ptr = self.ptr;
353 let bottom_ptr = NonNull::new(unsafe { self.get_ptr_unchecked(0, y) }).unwrap();
354 unsafe {
356 let split_base = self.split_base.unwrap_or(self.ptr.cast());
357 let mut top_grid = MutableSubgrid::new(top_ptr, self.width, y, self.stride);
358 let mut bottom_grid =
359 MutableSubgrid::new(bottom_ptr, self.width, self.height - y, self.stride);
360 top_grid.split_base = Some(split_base);
361 bottom_grid.split_base = Some(split_base);
362 (top_grid, bottom_grid)
363 }
364 }
365
366 pub fn split_vertical_in_place(&mut self, y: usize) -> MutableSubgrid<'g, V> {
371 assert!(y <= self.height);
372
373 let bottom_height = self.height - y;
374 let bottom_ptr = NonNull::new(unsafe { self.get_ptr_unchecked(0, y) }).unwrap();
375 unsafe {
377 let split_base = self.split_base.unwrap_or(self.ptr.cast());
378 self.height = y;
379 self.split_base = Some(split_base);
380 let mut bottom_grid =
381 MutableSubgrid::new(bottom_ptr, self.width, bottom_height, self.stride);
382 bottom_grid.split_base = Some(split_base);
383 bottom_grid
384 }
385 }
386
387 pub fn merge_horizontal_in_place(&mut self, right: Self) {
398 assert!(self.split_base.is_some());
399 assert_eq!(self.split_base, right.split_base);
400 assert_eq!(self.stride, right.stride);
401 assert_eq!(self.height, right.height);
402 assert!(self.stride >= self.width + right.width);
403 unsafe {
404 assert!(std::ptr::eq(
405 self.get_ptr_unchecked(self.width, 0) as *const _,
406 right.ptr.as_ptr() as *const _,
407 ));
408 }
409
410 let right_width = right.width;
411 self.width += right_width;
412 }
413
414 pub fn merge_vertical_in_place(&mut self, bottom: Self) {
425 assert!(self.split_base.is_some());
426 assert_eq!(self.split_base, bottom.split_base);
427 assert_eq!(self.stride, bottom.stride);
428 assert_eq!(self.width, bottom.width);
429 unsafe {
430 assert!(std::ptr::eq(
431 self.get_ptr_unchecked(0, self.height) as *const _,
432 bottom.ptr.as_ptr() as *const _,
433 ));
434 }
435
436 let bottom_height = bottom.height;
437 self.height += bottom_height;
438 }
439}
440
441impl<V: Copy> MutableSubgrid<'_, V> {
442 #[inline]
447 pub fn get(&self, x: usize, y: usize) -> V {
448 *self.get_ref(x, y)
449 }
450}
451
452impl<'g, V> MutableSubgrid<'g, V> {
453 pub fn into_groups(
460 self,
461 group_width: usize,
462 group_height: usize,
463 ) -> Vec<MutableSubgrid<'g, V>> {
464 assert!(
465 group_width > 0 && group_height > 0,
466 "expected group width and height to be nonzero, got width = {group_width}, height = {group_height}"
467 );
468
469 let num_cols = self.width.div_ceil(group_width);
470 let num_rows = self.height.div_ceil(group_height);
471 self.into_groups_with_fixed_count(group_width, group_height, num_cols, num_rows)
472 }
473
474 pub fn into_groups_with_fixed_count(
481 self,
482 group_width: usize,
483 group_height: usize,
484 num_cols: usize,
485 num_rows: usize,
486 ) -> Vec<MutableSubgrid<'g, V>> {
487 let MutableSubgrid {
488 ptr,
489 split_base,
490 width,
491 height,
492 stride,
493 ..
494 } = self;
495 let split_base = split_base.unwrap_or(ptr.cast());
496
497 let mut groups = Vec::with_capacity(num_cols * num_rows);
498 for gy in 0..num_rows {
499 let y = (gy * group_height).min(height);
500 let gh = (height - y).min(group_height);
501 let row_ptr = unsafe { ptr.as_ptr().add(y * stride) };
502 for gx in 0..num_cols {
503 let x = (gx * group_width).min(width);
504 let gw = (width - x).min(group_width);
505 let ptr = unsafe { row_ptr.add(x) };
506
507 let mut grid =
508 unsafe { MutableSubgrid::new(NonNull::new(ptr).unwrap(), gw, gh, stride) };
509 grid.split_base = Some(split_base);
510 groups.push(grid);
511 }
512 }
513
514 groups
515 }
516}
517
518impl MutableSubgrid<'_, f32> {
519 pub fn as_vectored<V: SimdVector>(&mut self) -> Option<MutableSubgrid<'_, V>> {
525 assert!(
526 V::available(),
527 "Vector type `{}` is not supported by current CPU",
528 std::any::type_name::<V>()
529 );
530
531 let mask = V::SIZE - 1;
532 let align_mask = std::mem::align_of::<V>() - 1;
533
534 (self.ptr.as_ptr() as usize & align_mask == 0
535 && self.width & mask == 0
536 && self.stride & mask == 0)
537 .then(|| MutableSubgrid {
538 ptr: self.ptr.cast::<V>(),
539 split_base: self.split_base,
540 width: self.width / V::SIZE,
541 height: self.height,
542 stride: self.stride / V::SIZE,
543 _marker: Default::default(),
544 })
545 }
546}
547
548impl<'g> MutableSubgrid<'g, f32> {
549 pub fn into_i32(self) -> MutableSubgrid<'g, i32> {
551 MutableSubgrid {
553 ptr: self.ptr.cast(),
554 split_base: self.split_base,
555 width: self.width,
556 height: self.height,
557 stride: self.stride,
558 _marker: Default::default(),
559 }
560 }
561}