Skip to main content

aeon_tk/image/
mod.rs

1use reborrow::{Reborrow, ReborrowMut};
2use serde::{Deserialize, Serialize};
3use std::marker::PhantomData;
4use std::ops::{Bound, Range, RangeBounds};
5use std::slice::SliceIndex;
6
7/// Several fields of data spread among
8#[derive(Clone, Debug, Default, Serialize, Deserialize, datasize::DataSize)]
9pub struct Image {
10    data: Vec<f64>,
11    channels: usize,
12}
13
14impl Image {
15    /// Allocates a new image with the specified number of channels and nodes
16    pub fn new(channels: usize, nodes: usize) -> Self {
17        Self {
18            data: vec![0.0; channels * nodes],
19            channels,
20        }
21    }
22
23    pub fn reinit(&mut self, channels: usize, nodes: usize) {
24        self.data.resize(channels * nodes, 0.0);
25        self.channels = channels;
26    }
27
28    pub fn resize(&mut self, nodes: usize) {
29        self.reinit(self.channels, nodes);
30    }
31
32    pub fn num_nodes(&self) -> usize {
33        if self.channels == 0 {
34            return 0;
35        }
36
37        self.data.len() / self.channels
38    }
39
40    pub fn is_empty(&self) -> bool {
41        self.num_nodes() == 0 || self.num_channels() == 0
42    }
43
44    /// Constructs a new system vector from the given data. `data.len()` must be divisible by `field_count::<Label>()`.
45    pub fn from_storage(data: Vec<f64>, channels: usize) -> Self {
46        debug_assert!(data.len() % channels == 0);
47        Self { data, channels }
48    }
49
50    /// Transforms a system vector back into a linear vector
51    pub fn into_storage(self) -> Vec<f64> {
52        self.data
53    }
54
55    pub fn storage(&self) -> &[f64] {
56        &self.data
57    }
58
59    pub fn storage_mut(&mut self) -> &mut [f64] {
60        &mut self.data
61    }
62
63    pub fn num_channels(&self) -> usize {
64        self.channels
65    }
66
67    pub fn channels(&self) -> Range<usize> {
68        0..self.channels
69    }
70
71    pub fn channel(&self, channel: usize) -> &[f64] {
72        let stride = self.data.len() / self.channels;
73        &self.data[stride * channel..stride * (channel + 1)]
74    }
75
76    pub fn channel_mut(&mut self, channel: usize) -> &mut [f64] {
77        let stride = self.data.len() / self.channels;
78        &mut self.data[stride * channel..stride * (channel + 1)]
79    }
80
81    pub fn as_ref(&self) -> ImageRef<'_> {
82        ImageRef::from_storage(&self.data, self.channels)
83    }
84
85    pub fn as_mut(&mut self) -> ImageMut<'_>  {
86        ImageMut::from_storage(&mut self.data, self.channels)
87    }
88
89    pub fn slice<R>(&self, range: R) -> ImageRef<'_>
90    where
91        R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
92    {
93        let bounds = bounds_to_range(self.num_nodes(), range);
94        let length = bounds.end - bounds.start;
95
96        ImageRef {
97            ptr: self.data.as_ptr(),
98            total: self.data.len(),
99            offset: bounds.start,
100            length,
101            channels: self.channels,
102            _marker: PhantomData,
103        }
104    }
105
106    pub fn slice_mut<R>(&mut self, range: R) -> ImageMut<'_>
107    where
108        R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
109    {
110        let bounds = bounds_to_range(self.num_nodes(), range);
111        let length = bounds.end - bounds.start;
112
113        ImageMut {
114            ptr: self.data.as_mut_ptr(),
115            total: self.data.len(),
116            offset: bounds.start,
117            length,
118            channels: self.channels,
119            _marker: PhantomData,
120        }
121    }
122}
123
/// Converts a generic range into a concrete half-open `start..end` range.
///
/// `total` is used as the end when the range is unbounded above. The result is
/// *not* checked against `total`; callers that index with it must do so themselves.
///
/// # Panics
/// Panics if the range is reversed (`start > end`), or if an excluded start /
/// included end of `usize::MAX` cannot be converted to a half-open bound.
fn bounds_to_range<R>(total: usize, range: R) -> Range<usize>
where
    R: RangeBounds<usize>,
{
    let start = match range.start_bound() {
        Bound::Included(&i) => i,
        Bound::Excluded(&i) => i
            .checked_add(1)
            .expect("range start bound overflows usize"),
        Bound::Unbounded => 0,
    };

    let end = match range.end_bound() {
        Bound::Included(&i) => i.checked_add(1).expect("range end bound overflows usize"),
        Bound::Excluded(&i) => i,
        Bound::Unbounded => total,
    };

    // Every caller computes `end - start`; a reversed range would underflow there.
    assert!(
        start <= end,
        "range starts at {} but ends at {}",
        start,
        end
    );

    start..end
}
143
/// A borrowed, read-only view of a node range across every channel of an image.
///
/// Channel `c` of the view is the `length` values starting at buffer index
/// `c * (total / channels) + offset`.
#[derive(Clone, Copy, Debug)]
pub struct ImageRef<'a> {
    /// Start of the underlying channel-major buffer.
    ptr: *const f64,
    /// Length of the underlying buffer (all channels).
    total: usize,
    /// Index of the first viewed node within each channel.
    offset: usize,
    /// Number of viewed nodes per channel.
    length: usize,
    /// Number of channels in the view.
    channels: usize,
    /// Ties the view to the lifetime of the borrowed buffer.
    _marker: PhantomData<&'a ()>,
}
154
155impl<'a> ImageRef<'a> {
156    pub fn empty() -> Self {
157        Self::from_storage(&[], 0)
158    }
159
160    /// Builds a system slice from a contiguous chunk of data.
161    pub fn from_storage(data: &'a [f64], channels: usize) -> Self {
162        let mut length = 0;
163
164        if channels != 0 {
165            assert!(data.len() % channels == 0);
166            length = data.len() / channels;
167        }
168
169        Self {
170            ptr: data.as_ptr(),
171            total: data.len(),
172            offset: 0,
173            length,
174            channels,
175            _marker: PhantomData,
176        }
177    }
178
179    /// Returns the size of the system slice.
180    pub fn num_nodes(&self) -> usize {
181        self.length
182    }
183
184    // pub fn len(&self) -> usize {
185    //     self.num_nodes() * self.num_channels()
186    // }
187
188    pub fn is_empty(&self) -> bool {
189        self.length == 0 || self.channels == 0
190    }
191
192    pub fn num_channels(&self) -> usize {
193        self.channels
194    }
195
196    pub fn channels(&self) -> Range<usize> {
197        0..self.channels
198    }
199
200    fn stride(&self) -> usize {
201        debug_assert!(self.channels >= 1);
202        self.total / self.channels
203    }
204
205    pub fn split_channels(self, split: usize) -> (ImageRef<'a>, ImageRef<'a>) {
206        assert!(split <= self.channels);
207
208        let left_channels = split;
209        let right_channels = self.channels - split;
210
211        let ptr = self.ptr;
212        let length = self.length;
213        let offset = self.offset;
214
215        let left_total = left_channels * self.stride();
216        let right_total = right_channels * self.stride();
217
218        debug_assert_eq!(left_total + right_total, self.total);
219
220        let left_ptr = ptr;
221        let right_ptr = unsafe { ptr.add(left_total) };
222
223        (
224            ImageRef {
225                ptr: left_ptr,
226                total: left_total,
227                offset,
228                length,
229                channels: left_channels,
230                _marker: PhantomData,
231            },
232            ImageRef {
233                ptr: right_ptr,
234                total: right_total,
235                offset,
236                length,
237                channels: right_channels,
238                _marker: PhantomData,
239            },
240        )
241    }
242
243    /// Gets an immutable reference to the given field.
244    pub fn channel(&self, channel: usize) -> &[f64] {
245        debug_assert!(channel < self.num_channels());
246
247        unsafe {
248            std::slice::from_raw_parts(
249                self.ptr.add(self.stride() * channel + self.offset),
250                self.length,
251            )
252        }
253    }
254
255    /// Takes a subslice of the existing slice.
256    pub fn slice<R>(&self, range: R) -> ImageRef<'_>
257    where
258        R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
259    {
260        let bounds = bounds_to_range(self.length, range);
261        let length = bounds.end - bounds.start;
262
263        debug_assert!(self.channels == 0 || length <= self.length);
264
265        ImageRef {
266            ptr: self.ptr,
267            total: self.total,
268            offset: self.offset + bounds.start,
269            length,
270            channels: self.channels,
271            _marker: PhantomData,
272        }
273    }
274
275    pub fn to_owned(&self) -> Image {
276        let mut data = Vec::with_capacity(self.length * self.channels);
277
278        for channel in 0..self.channels {
279            data.extend_from_slice(self.channel(channel));
280        }
281
282        Image::from_storage(data, self.channels)
283    }
284}
285
286impl<'short> Reborrow<'short> for ImageRef<'_> {
287    type Target = ImageRef<'short>;
288
289    fn rb(&'short self) -> Self::Target {
290        ImageRef {
291            ptr: self.ptr,
292            total: self.total,
293            offset: self.offset,
294            length: self.length,
295            channels: self.channels,
296            _marker: PhantomData,
297        }
298    }
299}
300
301impl<'a> From<&'a [f64]> for ImageRef<'a> {
302    fn from(value: &'a [f64]) -> Self {
303        ImageRef {
304            ptr: value.as_ptr(),
305            total: value.len(),
306            offset: 0,
307            length: value.len(),
308            channels: 1,
309            _marker: PhantomData,
310        }
311    }
312}
313
314impl<'a> From<&'a mut [f64]> for ImageRef<'a> {
315    fn from(value: &'a mut [f64]) -> Self {
316        ImageRef {
317            ptr: value.as_ptr(),
318            total: value.len(),
319            offset: 0,
320            length: value.len(),
321            channels: 1,
322            _marker: PhantomData,
323        }
324    }
325}
326
327unsafe impl Send for ImageRef<'_> {}
328unsafe impl Sync for ImageRef<'_> {}
329
/// A mutable view of a node range across every channel of an image.
///
/// Channel `c` of the view is the `length` values starting at buffer index
/// `c * (total / channels) + offset`.
#[derive(Debug)]
pub struct ImageMut<'a> {
    /// Start of the underlying channel-major buffer.
    ptr: *mut f64,
    /// Length of the underlying buffer (all channels).
    total: usize,
    /// Index of the first viewed node within each channel.
    offset: usize,
    /// Number of viewed nodes per channel.
    length: usize,
    /// Number of channels in the view.
    channels: usize,
    /// Ties the view to an exclusive borrow of the buffer for `'a`.
    _marker: PhantomData<&'a mut ()>,
}
339
340impl<'a> ImageMut<'a> {
341    /// Builds a mutable system slice from contiguous data.
342    pub fn from_storage(data: &'a mut [f64], channels: usize) -> Self {
343        let mut length = 0;
344
345        if channels != 0 {
346            assert!(data.len() % channels == 0);
347            length = data.len() / channels;
348        }
349
350        Self {
351            ptr: data.as_mut_ptr(),
352            total: data.len(),
353            offset: 0,
354            length,
355            channels,
356            _marker: PhantomData,
357        }
358    }
359
360    /// Returns the size of the system slice.
361    pub fn num_nodes(&self) -> usize {
362        self.length
363    }
364
365    // pub fn len(&self) -> usize {
366    //     self.length * self.channels
367    // }
368
369    pub fn is_empty(&self) -> bool {
370        self.length == 0 || self.channels == 0
371    }
372
373    pub fn num_channels(&self) -> usize {
374        self.channels
375    }
376
377    pub fn channels(&self) -> Range<usize> {
378        0..self.channels
379    }
380
381    pub fn split_channels(self, split: usize) -> (ImageMut<'a>, ImageMut<'a>) {
382        assert!(split < self.channels);
383        let left_channels = split;
384        let right_channels = self.channels - split;
385
386        let ptr = self.ptr;
387        let length = self.length;
388        let offset = self.offset;
389
390        let left_total = left_channels * self.stride();
391        let right_total = right_channels * self.stride();
392
393        debug_assert_eq!(left_total + right_total, self.total);
394
395        let left_ptr = ptr;
396        let right_ptr = unsafe { ptr.add(left_total) };
397
398        (
399            ImageMut {
400                ptr: left_ptr,
401                total: left_total,
402                offset,
403                length,
404                channels: left_channels,
405                _marker: PhantomData,
406            },
407            ImageMut {
408                ptr: right_ptr,
409                total: right_total,
410                offset,
411                length,
412                channels: right_channels,
413                _marker: PhantomData,
414            },
415        )
416    }
417
418    fn stride(&self) -> usize {
419        debug_assert!(self.channels >= 1);
420        self.total / self.channels
421    }
422
423    /// Gets an immutable reference to the given field.
424    pub fn channel(&self, channel: usize) -> &[f64] {
425        debug_assert!(channel < self.num_channels());
426
427        unsafe {
428            std::slice::from_raw_parts(
429                self.ptr.add(self.stride() * channel + self.offset),
430                self.length,
431            )
432        }
433    }
434
435    /// Retrieves a mutable slice of the given field.
436    pub fn channel_mut(&mut self, channel: usize) -> &mut [f64] {
437        debug_assert!(channel < self.num_channels());
438
439        unsafe {
440            std::slice::from_raw_parts_mut(
441                self.ptr.add(self.stride() * channel + self.offset),
442                self.length,
443            )
444        }
445    }
446
447    /// Takes a subslice of this slice.
448    pub fn slice<R>(&self, range: R) -> ImageRef<'_>
449    where
450        R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
451    {
452        let bounds = bounds_to_range(self.length, range);
453        let length = bounds.end - bounds.start;
454
455        debug_assert!(self.channels == 0 || length <= self.length);
456
457        ImageRef {
458            ptr: self.ptr,
459            total: self.total,
460            offset: self.offset + bounds.start,
461            length,
462            channels: self.channels,
463            _marker: PhantomData,
464        }
465    }
466
467    /// Takes a mutable subslice of this slice.
468    pub fn slice_mut<R>(&mut self, range: R) -> ImageMut<'_>
469    where
470        R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
471    {
472        let bounds = bounds_to_range(self.length, range);
473        let length = bounds.end - bounds.start;
474
475        debug_assert!(self.channels == 0 || length <= self.length);
476
477        ImageMut {
478            ptr: self.ptr,
479            total: self.total,
480            offset: self.offset + bounds.start,
481            length,
482            channels: self.channels,
483            _marker: PhantomData,
484        }
485    }
486
487    pub fn to_owned(&self) -> Image {
488        let mut data = Vec::with_capacity(self.length * self.channels);
489
490        for channel in 0..self.channels {
491            data.extend_from_slice(self.channel(channel));
492        }
493
494        Image::from_storage(data, self.channels)
495    }
496}
497
498impl<'a> From<&'a mut [f64]> for ImageMut<'a> {
499    fn from(value: &'a mut [f64]) -> Self {
500        ImageMut {
501            ptr: value.as_mut_ptr(),
502            total: value.len(),
503            offset: 0,
504            length: value.len(),
505            channels: 1,
506            _marker: PhantomData,
507        }
508    }
509}
510
511impl<'short> Reborrow<'short> for ImageMut<'_> {
512    type Target = ImageRef<'short>;
513
514    fn rb(&'short self) -> Self::Target {
515        ImageRef {
516            ptr: self.ptr,
517            total: self.total,
518            offset: self.offset,
519            length: self.length,
520            channels: self.channels,
521            _marker: PhantomData,
522        }
523    }
524}
525
526impl<'short> ReborrowMut<'short> for ImageMut<'_> {
527    type Target = ImageMut<'short>;
528
529    fn rb_mut(&'short mut self) -> Self::Target {
530        ImageMut {
531            ptr: self.ptr,
532            total: self.total,
533            offset: self.offset,
534            length: self.length,
535            channels: self.channels,
536            _marker: PhantomData,
537        }
538    }
539}
540
541unsafe impl Send for ImageMut<'_> {}
542unsafe impl Sync for ImageMut<'_> {}
543
/// A cloneable, unchecked handle to a node range of a mutably borrowed image.
///
/// Unlike [`ImageMut`], copies of this handle can coexist. In exchange, every
/// accessor that reads or writes data is `unsafe`, and callers must rule out
/// conflicting accesses themselves.
#[derive(Clone, Debug)]
pub struct ImageShared<'a> {
    // Start of the underlying channel-major buffer.
    ptr: *mut f64,
    // Length of the underlying buffer (all channels).
    total: usize,
    // Index of the first covered node within each channel.
    offset: usize,
    // Number of covered nodes per channel.
    length: usize,
    // Number of channels covered.
    channels: usize,
    // Ties the handle to the exclusive borrow it was created from.
    _marker: PhantomData<&'a mut ()>,
}
554
555impl ImageShared<'_> {
556    /// Returns the size of the system slice.
557    pub fn num_nodes(&self) -> usize {
558        self.length
559    }
560
561    // pub fn len(&self) -> usize {
562    //     self.length * self.channels
563    // }
564
565    pub fn is_empty(&self) -> bool {
566        self.length == 0 || self.num_channels() == 0
567    }
568
569    pub fn num_channels(&self) -> usize {
570        self.channels
571    }
572
573    pub fn channels(&self) -> Range<usize> {
574        0..self.channels
575    }
576
577    fn stride(&self) -> usize {
578        debug_assert!(self.channels >= 1);
579        self.total / self.channels
580    }
581
582    /// Retrieves an immutable slice to the given field.
583    pub unsafe fn channel(&self, channel: usize) -> &[f64] {
584        debug_assert!(channel < self.num_channels());
585
586        unsafe {
587            std::slice::from_raw_parts(
588                self.ptr.add(self.stride() * channel + self.offset),
589                self.length,
590            )
591        }
592    }
593
594    /// Retrieves a mutable slice of the given field.
595    pub unsafe fn channel_mut(&self, channel: usize) -> &mut [f64] {
596        debug_assert!(channel < self.num_channels());
597
598        unsafe {
599            std::slice::from_raw_parts_mut(
600                self.ptr.add(self.stride() * channel + self.offset),
601                self.length,
602            )
603        }
604    }
605
606    /// Retrieves an immutable reference to a slice of the system.
607    ///
608    /// # Safety
609    /// No other mutable refernces may refer to any element of this slice while it is alive.
610    pub unsafe fn slice<R>(&self, range: R) -> ImageRef<'_>
611    where
612        R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
613    {
614        let bounds = bounds_to_range(self.length, range);
615        let length = bounds.end - bounds.start;
616
617        debug_assert!(self.channels == 0 || length <= self.length);
618
619        ImageRef {
620            ptr: self.ptr,
621            total: self.total,
622            offset: self.offset + bounds.start,
623            length,
624            channels: self.channels,
625            _marker: PhantomData,
626        }
627    }
628
629    /// Retrieves a mutable reference to a slice of the system.
630    ///
631    /// # Safety
632    /// No other refernces may refer to any element of this slice while it is alive.
633    pub unsafe fn slice_mut<R>(&self, range: R) -> ImageMut<'_>
634    where
635        R: RangeBounds<usize> + SliceIndex<[f64], Output = [f64]> + Clone,
636    {
637        let bounds = bounds_to_range(self.length, range);
638        let length = bounds.end - bounds.start;
639
640        debug_assert!(self.channels == 0 || length <= self.length);
641
642        ImageMut {
643            ptr: self.ptr,
644            total: self.total,
645            offset: self.offset + bounds.start,
646            length,
647            channels: self.channels,
648            _marker: PhantomData,
649        }
650    }
651}
652
653impl<'a> From<ImageMut<'a>> for ImageShared<'a> {
654    fn from(value: ImageMut<'a>) -> Self {
655        ImageShared {
656            ptr: value.ptr,
657            total: value.total,
658            offset: value.offset,
659            length: value.length,
660            channels: value.channels,
661            _marker: PhantomData,
662        }
663    }
664}
665
666unsafe impl Send for ImageShared<'_> {}
667unsafe impl Sync for ImageShared<'_> {}
668
#[cfg(test)]
mod tests {
    use super::*;

    const FIRST_CH: usize = 0;
    const SECOND_CH: usize = 1;
    const THIRD_CH: usize = 2;

    /// Returns `true` when every value in `values` equals `expected`.
    fn all_equal(values: &[f64], expected: f64) -> bool {
        values.iter().all(|&v| v == expected)
    }

    /// Writes through a shared handle into a node sub-range, then checks the raw storage.
    #[test]
    fn basic() {
        let mut fields = Image::new(3, 3);

        {
            let shared: ImageShared = fields.as_mut().into();
            let mut middle = unsafe { shared.slice_mut(1..2) };

            for (channel, value) in [(FIRST_CH, 1.0), (SECOND_CH, 2.0), (THIRD_CH, 3.0)] {
                middle.channel_mut(channel).fill(value);
            }
        }

        assert_eq!(
            fields.storage(),
            &[0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0]
        );

        assert!(Image::new(0, 0).is_empty());
    }

    /// Splits images by channel (and by node range) through shared and mutable views.
    #[test]
    fn pair() {
        let mut data = Image::new(5, 10);

        // Channel `c` holds the constant value `c`.
        for channel in data.channels() {
            data.channel_mut(channel).fill(channel as f64);
        }

        {
            let view = data.as_ref();
            let (left, right) = view.split_channels(3);
            assert_eq!(left.num_channels(), 3);
            assert_eq!(right.num_channels(), 2);

            for channel in 0..3 {
                assert!(all_equal(left.channel(channel), channel as f64));
            }
            for channel in 0..2 {
                assert!(all_equal(right.channel(channel), (channel + 3) as f64));
            }
        }

        {
            let view: ImageMut<'_> = ImageMut::from_storage(data.storage_mut(), 5);
            let (left, right) = view.split_channels(3);

            for channel in 0..3 {
                assert!(all_equal(left.channel(channel), channel as f64));
            }
            for channel in 0..2 {
                assert!(all_equal(right.channel(channel), (channel + 3) as f64));
            }
        }

        let image = Image::from_storage((0..15).map(|i| i as f64).collect(), 3);

        {
            let view = image.as_ref();
            let (left, right) = view.split_channels(2);

            assert_eq!(left.channel(0), &[0.0, 1.0, 2.0, 3.0, 4.0]);
            assert_eq!(left.channel(1), &[5.0, 6.0, 7.0, 8.0, 9.0]);
            assert_eq!(right.channel(0), &[10.0, 11.0, 12.0, 13.0, 14.0]);
        }

        {
            let view = image.as_ref();
            let (front, back) = view.slice(2..4).split_channels(2);
            assert_eq!(front.channel(0), &[2.0, 3.0]);
            assert_eq!(front.channel(1), &[7.0, 8.0]);
            assert_eq!(back.channel(0), &[12.0, 13.0]);
        }
    }
}
749            assert_eq!(slice1.channel(1), &[7.0, 8.0]);
750            assert_eq!(slice2.channel(0), &[12.0, 13.0]);
751        }
752    }
753}