1use reborrow::{Reborrow, ReborrowMut};
2use serde::{Deserialize, Serialize};
3use std::marker::PhantomData;
4use std::ops::{Bound, Range, RangeBounds};
5use std::slice::SliceIndex;
6
/// Owned, channel-major storage for a set of per-node `f64` fields.
///
/// `data` holds `channels` contiguous runs of equal length, one per channel, so the value of
/// channel `c` at node `i` lives at `data[c * num_nodes + i]`.
#[derive(Clone, Debug, Default, Serialize, Deserialize, datasize::DataSize)]
pub struct Image {
    /// Flat backing buffer, laid out channel after channel.
    data: Vec<f64>,
    /// Number of channels stored in `data`.
    channels: usize,
}
13
14impl Image {
15 pub fn new(channels: usize, nodes: usize) -> Self {
17 Self {
18 data: vec![0.0; channels * nodes],
19 channels,
20 }
21 }
22
23 pub fn reinit(&mut self, channels: usize, nodes: usize) {
24 self.data.resize(channels * nodes, 0.0);
25 self.channels = channels;
26 }
27
28 pub fn resize(&mut self, nodes: usize) {
29 self.reinit(self.channels, nodes);
30 }
31
32 pub fn num_nodes(&self) -> usize {
33 if self.channels == 0 {
34 return 0;
35 }
36
37 self.data.len() / self.channels
38 }
39
40 pub fn is_empty(&self) -> bool {
41 self.num_nodes() == 0 || self.num_channels() == 0
42 }
43
44 pub fn from_storage(data: Vec<f64>, channels: usize) -> Self {
46 debug_assert!(data.len() % channels == 0);
47 Self { data, channels }
48 }
49
50 pub fn into_storage(self) -> Vec<f64> {
52 self.data
53 }
54
55 pub fn storage(&self) -> &[f64] {
56 &self.data
57 }
58
59 pub fn storage_mut(&mut self) -> &mut [f64] {
60 &mut self.data
61 }
62
63 pub fn num_channels(&self) -> usize {
64 self.channels
65 }
66
67 pub fn channels(&self) -> Range<usize> {
68 0..self.channels
69 }
70
71 pub fn channel(&self, channel: usize) -> &[f64] {
72 let stride = self.data.len() / self.channels;
73 &self.data[stride * channel..stride * (channel + 1)]
74 }
75
76 pub fn channel_mut(&mut self, channel: usize) -> &mut [f64] {
77 let stride = self.data.len() / self.channels;
78 &mut self.data[stride * channel..stride * (channel + 1)]
79 }
80
81 pub fn as_ref(&self) -> ImageRef<'_> {
82 ImageRef::from_storage(&self.data, self.channels)
83 }
84
85 pub fn as_mut(&mut self) -> ImageMut<'_> {
86 ImageMut::from_storage(&mut self.data, self.channels)
87 }
88
89 pub fn slice<R>(&self, range: R) -> ImageRef<'_>
90 where
91 R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
92 {
93 let bounds = bounds_to_range(self.num_nodes(), range);
94 let length = bounds.end - bounds.start;
95
96 ImageRef {
97 ptr: self.data.as_ptr(),
98 total: self.data.len(),
99 offset: bounds.start,
100 length,
101 channels: self.channels,
102 _marker: PhantomData,
103 }
104 }
105
106 pub fn slice_mut<R>(&mut self, range: R) -> ImageMut<'_>
107 where
108 R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
109 {
110 let bounds = bounds_to_range(self.num_nodes(), range);
111 let length = bounds.end - bounds.start;
112
113 ImageMut {
114 ptr: self.data.as_mut_ptr(),
115 total: self.data.len(),
116 offset: bounds.start,
117 length,
118 channels: self.channels,
119 _marker: PhantomData,
120 }
121 }
122}
123
/// Resolves arbitrary `RangeBounds` into a half-open `start..end` range.
///
/// An unbounded start becomes `0` and an unbounded end becomes `total`. The result is *not*
/// checked against `total`; callers are responsible for validating it.
///
/// # Panics
///
/// Panics if an exclusive start or inclusive end of `usize::MAX` cannot be converted to
/// half-open form without overflowing (instead of silently wrapping to `0` in release builds).
fn bounds_to_range<R>(total: usize, range: R) -> Range<usize>
where
    R: RangeBounds<usize>,
{
    let start = match range.start_bound() {
        Bound::Included(&i) => i,
        Bound::Excluded(&i) => i
            .checked_add(1)
            .expect("attempted to index slice from after maximum usize"),
        Bound::Unbounded => 0,
    };

    let end = match range.end_bound() {
        Bound::Included(&i) => i
            .checked_add(1)
            .expect("attempted to index slice up to maximum usize"),
        Bound::Excluded(&i) => i,
        Bound::Unbounded => total,
    };

    start..end
}
143
/// Borrowed, read-only view of a node range of a channel-major image.
///
/// `ptr` points at the start of the whole channel-major buffer (not at the first viewed
/// node). Channel `c` of the view begins `c * (total / channels) + offset` elements after
/// `ptr` and is `length` elements long. The lifetime `'a` ties the view to the borrowed data.
#[derive(Clone, Copy)]
pub struct ImageRef<'a> {
    /// Start of the underlying channel-major buffer.
    ptr: *const f64,
    /// Total number of elements in the underlying buffer, across all of this view's channels.
    total: usize,
    /// Index of the first viewed node within each channel.
    offset: usize,
    /// Number of viewed nodes per channel.
    length: usize,
    /// Number of channels in the view.
    channels: usize,
    _marker: PhantomData<&'a ()>,
}
154
155impl<'a> ImageRef<'a> {
156 pub fn empty() -> Self {
157 Self::from_storage(&[], 0)
158 }
159
160 pub fn from_storage(data: &'a [f64], channels: usize) -> Self {
162 let mut length = 0;
163
164 if channels != 0 {
165 assert!(data.len() % channels == 0);
166 length = data.len() / channels;
167 }
168
169 Self {
170 ptr: data.as_ptr(),
171 total: data.len(),
172 offset: 0,
173 length,
174 channels,
175 _marker: PhantomData,
176 }
177 }
178
179 pub fn num_nodes(&self) -> usize {
181 self.length
182 }
183
184 pub fn is_empty(&self) -> bool {
189 self.length == 0 || self.channels == 0
190 }
191
192 pub fn num_channels(&self) -> usize {
193 self.channels
194 }
195
196 pub fn channels(&self) -> Range<usize> {
197 0..self.channels
198 }
199
200 fn stride(&self) -> usize {
201 debug_assert!(self.channels >= 1);
202 self.total / self.channels
203 }
204
205 pub fn split_channels(self, split: usize) -> (ImageRef<'a>, ImageRef<'a>) {
206 assert!(split <= self.channels);
207
208 let left_channels = split;
209 let right_channels = self.channels - split;
210
211 let ptr = self.ptr;
212 let length = self.length;
213 let offset = self.offset;
214
215 let left_total = left_channels * self.stride();
216 let right_total = right_channels * self.stride();
217
218 debug_assert_eq!(left_total + right_total, self.total);
219
220 let left_ptr = ptr;
221 let right_ptr = unsafe { ptr.add(left_total) };
222
223 (
224 ImageRef {
225 ptr: left_ptr,
226 total: left_total,
227 offset,
228 length,
229 channels: left_channels,
230 _marker: PhantomData,
231 },
232 ImageRef {
233 ptr: right_ptr,
234 total: right_total,
235 offset,
236 length,
237 channels: right_channels,
238 _marker: PhantomData,
239 },
240 )
241 }
242
243 pub fn channel(&self, channel: usize) -> &[f64] {
245 debug_assert!(channel < self.num_channels());
246
247 unsafe {
248 std::slice::from_raw_parts(
249 self.ptr.add(self.stride() * channel + self.offset),
250 self.length,
251 )
252 }
253 }
254
255 pub fn slice<R>(&self, range: R) -> ImageRef<'_>
257 where
258 R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
259 {
260 let bounds = bounds_to_range(self.length, range);
261 let length = bounds.end - bounds.start;
262
263 debug_assert!(self.channels == 0 || length <= self.length);
264
265 ImageRef {
266 ptr: self.ptr,
267 total: self.total,
268 offset: self.offset + bounds.start,
269 length,
270 channels: self.channels,
271 _marker: PhantomData,
272 }
273 }
274
275 pub fn to_owned(&self) -> Image {
276 let mut data = Vec::with_capacity(self.length * self.channels);
277
278 for channel in 0..self.channels {
279 data.extend_from_slice(self.channel(channel));
280 }
281
282 Image::from_storage(data, self.channels)
283 }
284}
285
286impl<'short> Reborrow<'short> for ImageRef<'_> {
287 type Target = ImageRef<'short>;
288
289 fn rb(&'short self) -> Self::Target {
290 ImageRef {
291 ptr: self.ptr,
292 total: self.total,
293 offset: self.offset,
294 length: self.length,
295 channels: self.channels,
296 _marker: PhantomData,
297 }
298 }
299}
300
301impl<'a> From<&'a [f64]> for ImageRef<'a> {
302 fn from(value: &'a [f64]) -> Self {
303 ImageRef {
304 ptr: value.as_ptr(),
305 total: value.len(),
306 offset: 0,
307 length: value.len(),
308 channels: 1,
309 _marker: PhantomData,
310 }
311 }
312}
313
314impl<'a> From<&'a mut [f64]> for ImageRef<'a> {
315 fn from(value: &'a mut [f64]) -> Self {
316 ImageRef {
317 ptr: value.as_ptr(),
318 total: value.len(),
319 offset: 0,
320 length: value.len(),
321 channels: 1,
322 _marker: PhantomData,
323 }
324 }
325}
326
// SAFETY: an `ImageRef` behaves like a shared borrow of `[f64]`: it only hands out shared
// `&[f64]` slices and never writes through `ptr`, and `&[f64]` is both `Send` and `Sync`.
unsafe impl Send for ImageRef<'_> {}
unsafe impl Sync for ImageRef<'_> {}
329
/// Borrowed, mutable view of a node range of a channel-major image.
///
/// Channel `c` of the view begins `c * (total / channels) + offset` elements after `ptr`
/// and is `length` elements long. The lifetime `'a` ties the view to the mutably borrowed
/// data.
pub struct ImageMut<'a> {
    /// Start of the underlying channel-major buffer.
    ptr: *mut f64,
    /// Total number of elements in the underlying buffer, across all of this view's channels.
    total: usize,
    /// Index of the first viewed node within each channel.
    offset: usize,
    /// Number of viewed nodes per channel.
    length: usize,
    /// Number of channels in the view.
    channels: usize,
    _marker: PhantomData<&'a mut ()>,
}
339
340impl<'a> ImageMut<'a> {
341 pub fn from_storage(data: &'a mut [f64], channels: usize) -> Self {
343 let mut length = 0;
344
345 if channels != 0 {
346 assert!(data.len() % channels == 0);
347 length = data.len() / channels;
348 }
349
350 Self {
351 ptr: data.as_mut_ptr(),
352 total: data.len(),
353 offset: 0,
354 length,
355 channels,
356 _marker: PhantomData,
357 }
358 }
359
360 pub fn num_nodes(&self) -> usize {
362 self.length
363 }
364
365 pub fn is_empty(&self) -> bool {
370 self.length == 0 || self.channels == 0
371 }
372
373 pub fn num_channels(&self) -> usize {
374 self.channels
375 }
376
377 pub fn channels(&self) -> Range<usize> {
378 0..self.channels
379 }
380
381 pub fn split_channels(self, split: usize) -> (ImageMut<'a>, ImageMut<'a>) {
382 assert!(split < self.channels);
383 let left_channels = split;
384 let right_channels = self.channels - split;
385
386 let ptr = self.ptr;
387 let length = self.length;
388 let offset = self.offset;
389
390 let left_total = left_channels * self.stride();
391 let right_total = right_channels * self.stride();
392
393 debug_assert_eq!(left_total + right_total, self.total);
394
395 let left_ptr = ptr;
396 let right_ptr = unsafe { ptr.add(left_total) };
397
398 (
399 ImageMut {
400 ptr: left_ptr,
401 total: left_total,
402 offset,
403 length,
404 channels: left_channels,
405 _marker: PhantomData,
406 },
407 ImageMut {
408 ptr: right_ptr,
409 total: right_total,
410 offset,
411 length,
412 channels: right_channels,
413 _marker: PhantomData,
414 },
415 )
416 }
417
418 fn stride(&self) -> usize {
419 debug_assert!(self.channels >= 1);
420 self.total / self.channels
421 }
422
423 pub fn channel(&self, channel: usize) -> &[f64] {
425 debug_assert!(channel < self.num_channels());
426
427 unsafe {
428 std::slice::from_raw_parts(
429 self.ptr.add(self.stride() * channel + self.offset),
430 self.length,
431 )
432 }
433 }
434
435 pub fn channel_mut(&mut self, channel: usize) -> &mut [f64] {
437 debug_assert!(channel < self.num_channels());
438
439 unsafe {
440 std::slice::from_raw_parts_mut(
441 self.ptr.add(self.stride() * channel + self.offset),
442 self.length,
443 )
444 }
445 }
446
447 pub fn slice<R>(&self, range: R) -> ImageRef<'_>
449 where
450 R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
451 {
452 let bounds = bounds_to_range(self.length, range);
453 let length = bounds.end - bounds.start;
454
455 debug_assert!(self.channels == 0 || length <= self.length);
456
457 ImageRef {
458 ptr: self.ptr,
459 total: self.total,
460 offset: self.offset + bounds.start,
461 length,
462 channels: self.channels,
463 _marker: PhantomData,
464 }
465 }
466
467 pub fn slice_mut<R>(&mut self, range: R) -> ImageMut<'_>
469 where
470 R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
471 {
472 let bounds = bounds_to_range(self.length, range);
473 let length = bounds.end - bounds.start;
474
475 debug_assert!(self.channels == 0 || length <= self.length);
476
477 ImageMut {
478 ptr: self.ptr,
479 total: self.total,
480 offset: self.offset + bounds.start,
481 length,
482 channels: self.channels,
483 _marker: PhantomData,
484 }
485 }
486
487 pub fn to_owned(&self) -> Image {
488 let mut data = Vec::with_capacity(self.length * self.channels);
489
490 for channel in 0..self.channels {
491 data.extend_from_slice(self.channel(channel));
492 }
493
494 Image::from_storage(data, self.channels)
495 }
496}
497
498impl<'a> From<&'a mut [f64]> for ImageMut<'a> {
499 fn from(value: &'a mut [f64]) -> Self {
500 ImageMut {
501 ptr: value.as_mut_ptr(),
502 total: value.len(),
503 offset: 0,
504 length: value.len(),
505 channels: 1,
506 _marker: PhantomData,
507 }
508 }
509}
510
511impl<'short> Reborrow<'short> for ImageMut<'_> {
512 type Target = ImageRef<'short>;
513
514 fn rb(&'short self) -> Self::Target {
515 ImageRef {
516 ptr: self.ptr,
517 total: self.total,
518 offset: self.offset,
519 length: self.length,
520 channels: self.channels,
521 _marker: PhantomData,
522 }
523 }
524}
525
526impl<'short> ReborrowMut<'short> for ImageMut<'_> {
527 type Target = ImageMut<'short>;
528
529 fn rb_mut(&'short mut self) -> Self::Target {
530 ImageMut {
531 ptr: self.ptr,
532 total: self.total,
533 offset: self.offset,
534 length: self.length,
535 channels: self.channels,
536 _marker: PhantomData,
537 }
538 }
539}
540
// SAFETY: an `ImageMut` behaves like `&mut [f64]`. Mutable access (`channel_mut`,
// `slice_mut`, `rb_mut`) requires `&mut self`, and shared access only yields `&[f64]`.
// `&mut [f64]` is both `Send` and `Sync`.
unsafe impl Send for ImageMut<'_> {}
unsafe impl Sync for ImageMut<'_> {}
543
/// Cloneable handle to a mutable image view whose data accessors are all `unsafe`.
///
/// Created from an [`ImageMut`]. Because it can be cloned while handing out `&mut` slices
/// from `&self`, callers must themselves guarantee that live references to overlapping data
/// never conflict. Uses the same layout as `ImageMut`: channel `c` begins
/// `c * (total / channels) + offset` elements after `ptr` and is `length` elements long.
#[derive(Debug, Clone)]
pub struct ImageShared<'a> {
    /// Start of the underlying channel-major buffer.
    ptr: *mut f64,
    /// Total number of elements in the underlying buffer, across all of this view's channels.
    total: usize,
    /// Index of the first viewed node within each channel.
    offset: usize,
    /// Number of viewed nodes per channel.
    length: usize,
    /// Number of channels in the view.
    channels: usize,
    _marker: PhantomData<&'a mut ()>,
}
554
555impl ImageShared<'_> {
556 pub fn num_nodes(&self) -> usize {
558 self.length
559 }
560
561 pub fn is_empty(&self) -> bool {
566 self.length == 0 || self.num_channels() == 0
567 }
568
569 pub fn num_channels(&self) -> usize {
570 self.channels
571 }
572
573 pub fn channels(&self) -> Range<usize> {
574 0..self.channels
575 }
576
577 fn stride(&self) -> usize {
578 debug_assert!(self.channels >= 1);
579 self.total / self.channels
580 }
581
582 pub unsafe fn channel(&self, channel: usize) -> &[f64] {
584 debug_assert!(channel < self.num_channels());
585
586 unsafe {
587 std::slice::from_raw_parts(
588 self.ptr.add(self.stride() * channel + self.offset),
589 self.length,
590 )
591 }
592 }
593
594 pub unsafe fn channel_mut(&self, channel: usize) -> &mut [f64] {
596 debug_assert!(channel < self.num_channels());
597
598 unsafe {
599 std::slice::from_raw_parts_mut(
600 self.ptr.add(self.stride() * channel + self.offset),
601 self.length,
602 )
603 }
604 }
605
606 pub unsafe fn slice<R>(&self, range: R) -> ImageRef<'_>
611 where
612 R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
613 {
614 let bounds = bounds_to_range(self.length, range);
615 let length = bounds.end - bounds.start;
616
617 debug_assert!(self.channels == 0 || length <= self.length);
618
619 ImageRef {
620 ptr: self.ptr,
621 total: self.total,
622 offset: self.offset + bounds.start,
623 length,
624 channels: self.channels,
625 _marker: PhantomData,
626 }
627 }
628
629 pub unsafe fn slice_mut<R>(&self, range: R) -> ImageMut<'_>
634 where
635 R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
636 {
637 let bounds = bounds_to_range(self.length, range);
638 let length = bounds.end - bounds.start;
639
640 debug_assert!(self.channels == 0 || length <= self.length);
641
642 ImageMut {
643 ptr: self.ptr,
644 total: self.total,
645 offset: self.offset + bounds.start,
646 length,
647 channels: self.channels,
648 _marker: PhantomData,
649 }
650 }
651}
652
653impl<'a> From<ImageMut<'a>> for ImageShared<'a> {
654 fn from(value: ImageMut<'a>) -> Self {
655 ImageShared {
656 ptr: value.ptr,
657 total: value.total,
658 offset: value.offset,
659 length: value.length,
660 channels: value.channels,
661 _marker: PhantomData,
662 }
663 }
664}
665
// SAFETY(review): `ImageShared` is `Clone` and hands out `&mut` slices from `&self`, so these
// impls let several threads hold handles to the same buffer. That is only sound because every
// data accessor is an `unsafe fn` whose caller must guarantee that concurrently live
// references never overlap mutably.
unsafe impl Send for ImageShared<'_> {}
unsafe impl Sync for ImageShared<'_> {}
668
#[cfg(test)]
mod tests {
    use super::*;

    // Named channel indices for readability in `basic`.
    const FIRST_CH: usize = 0;
    const SECOND_CH: usize = 1;
    const THIRD_CH: usize = 2;

    /// Writes node 1 of every channel through an `ImageShared` slice and checks the
    /// resulting channel-major buffer; also checks that a 0x0 image is empty.
    #[test]
    fn basic() {
        let mut fields = Image::new(3, 3);

        {
            let shared: ImageShared = fields.as_mut().into();
            // Single-node view (node 1) across all three channels.
            let mut slice = unsafe { shared.slice_mut(1..2) };

            slice.channel_mut(FIRST_CH).fill(1.0);
            slice.channel_mut(SECOND_CH).fill(2.0);
            slice.channel_mut(THIRD_CH).fill(3.0);
        }

        let buffer = fields.storage();

        // Channel-major layout: each channel's 3 nodes are contiguous.
        assert_eq!(buffer, &[0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0]);

        let empty = Image::new(0, 0);
        assert!(empty.is_empty());
    }

    /// Exercises `split_channels` on shared and mutable views, and on a sliced view.
    #[test]
    fn pair() {
        let mut data = Image::new(5, 10);

        // Fill channel `c` with the value `c` so each split half can be identified.
        data.channel_mut(0).fill(0.0);
        data.channel_mut(1).fill(1.0);
        data.channel_mut(2).fill(2.0);
        data.channel_mut(3).fill(3.0);
        data.channel_mut(4).fill(4.0);

        // Read-only split: channels 0..3 on the left, 3..5 on the right.
        {
            let data = data.as_ref();
            let (left, right) = data.split_channels(3);
            assert_eq!(left.num_channels(), 3);
            assert_eq!(right.num_channels(), 2);

            assert!(left.channel(0).iter().all(|v| *v == 0.0));
            assert!(left.channel(1).iter().all(|v| *v == 1.0));
            assert!(left.channel(2).iter().all(|v| *v == 2.0));
            assert!(right.channel(0).iter().all(|v| *v == 3.0));
            assert!(right.channel(1).iter().all(|v| *v == 4.0));
        }

        // Same split through a mutable view built directly from the storage.
        {
            let slice: ImageMut<'_> = ImageMut::from_storage(data.storage_mut(), 5);
            let (left, right) = slice.split_channels(3);

            assert!(left.channel(0).iter().all(|v| *v == 0.0));
            assert!(left.channel(1).iter().all(|v| *v == 1.0));
            assert!(left.channel(2).iter().all(|v| *v == 2.0));
            assert!(right.channel(0).iter().all(|v| *v == 3.0));
            assert!(right.channel(1).iter().all(|v| *v == 4.0));
        }

        // 3 channels x 5 nodes holding 0.0..=14.0, so every element is distinguishable.
        let data = (0..15).map(|i| i as f64).collect::<Vec<_>>();
        let image = Image::from_storage(data, 3);

        {
            let image = image.as_ref();
            let (left, right) = image.split_channels(2);

            assert_eq!(left.channel(0), &[0.0, 1.0, 2.0, 3.0, 4.0]);
            assert_eq!(left.channel(1), &[5.0, 6.0, 7.0, 8.0, 9.0]);
            assert_eq!(right.channel(0), &[10.0, 11.0, 12.0, 13.0, 14.0]);
        }
        // Slicing nodes 2..4 first, then splitting, keeps the node offset in both halves.
        {
            let image = image.as_ref();
            let (slice1, slice2) = image.slice(2..4).split_channels(2);
            assert_eq!(slice1.channel(0), &[2.0, 3.0]);
            assert_eq!(slice1.channel(1), &[7.0, 8.0]);
            assert_eq!(slice2.channel(0), &[12.0, 13.0]);
        }
    }
}