1use crate::err::MismatchedSize;
30use crate::ForgeError;
31use std::fmt::Debug;
32
/// Backing storage for a mutable element buffer.
///
/// The elements either live in a mutable slice handed in by the caller or in a
/// vector owned by the store itself.
#[derive(Debug)]
pub enum BufferStore<'a, T>
where
    T: Copy + Debug,
{
    /// Mutable slice borrowed from the caller for the lifetime `'a`.
    Borrowed(&'a mut [T]),
    /// Heap allocation owned by the store.
    Owned(Vec<T>),
}
38
39impl<T: Copy + Debug> BufferStore<'_, T> {
40 #[allow(clippy::should_implement_trait)]
41 pub fn borrow(&self) -> &[T] {
42 match self {
43 Self::Borrowed(p_ref) => p_ref,
44 Self::Owned(vec) => vec,
45 }
46 }
47
48 #[allow(clippy::should_implement_trait)]
49 pub fn borrow_mut(&mut self) -> &mut [T] {
50 match self {
51 Self::Borrowed(p_ref) => p_ref,
52 Self::Owned(vec) => vec,
53 }
54 }
55}
56
/// Read-only image whose rows hold `N` interleaved elements of `T` per pixel.
///
/// The element buffer is a [`std::borrow::Cow`], so it may either borrow a
/// caller-provided slice or own its allocation.
pub struct GainImage<'a, T, const N: usize>
where
    T: Clone + Copy + Default + Debug,
{
    /// Element buffer, borrowed or owned.
    pub data: std::borrow::Cow<'a, [T]>,
    /// Image width in pixels; one row carries `width * N` elements.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
    /// Distance between the starts of consecutive rows, counted in elements of `T`.
    pub stride: usize,
}
65
66pub struct GainImageMut<'a, T: Clone + Copy + Default + Debug, const N: usize> {
68 pub data: BufferStore<'a, T>,
69 pub width: usize,
70 pub height: usize,
71 pub stride: usize,
73}
74
75impl<'a, T: Clone + Copy + Default + Debug, const N: usize> GainImage<'a, T, N> {
76 pub fn alloc(width: usize, height: usize) -> Self {
78 Self {
79 data: std::borrow::Cow::Owned(vec![T::default(); width * height * N]),
80 width,
81 height,
82 stride: width * N,
83 }
84 }
85
86 pub fn borrow(arr: &'a [T], width: usize, height: usize) -> Self {
89 Self {
90 data: std::borrow::Cow::Borrowed(arr),
91 width,
92 height,
93 stride: width * N,
94 }
95 }
96
97 #[inline]
99 pub fn size_matches(&self, other: &GainImage<'_, T, N>) -> Result<(), ForgeError> {
100 if self.width == other.width && self.height == other.height {
101 return Ok(());
102 }
103 Err(ForgeError::ImageSizeMismatch)
104 }
105
106 #[inline]
108 pub fn size_matches_arb<const G: usize>(
109 &self,
110 other: &GainImage<'_, T, G>,
111 ) -> Result<(), ForgeError> {
112 if self.width == other.width && self.height == other.height {
113 return Ok(());
114 }
115 Err(ForgeError::ImageSizeMismatch)
116 }
117
118 #[inline]
120 pub fn size_matches_mut(&self, other: &GainImageMut<'_, T, N>) -> Result<(), ForgeError> {
121 if self.width == other.width && self.height == other.height {
122 return Ok(());
123 }
124 Err(ForgeError::ImageSizeMismatch)
125 }
126
127 #[inline]
129 pub fn check_layout_channels(&self, cn: usize) -> Result<(), ForgeError> {
130 if self.width == 0 || self.height == 0 {
131 return Err(ForgeError::ZeroBaseSize);
132 }
133 let data_len = self.data.as_ref().len();
134 if data_len < self.stride * (self.height - 1) + self.width * cn {
135 return Err(ForgeError::MinimumSliceSizeMismatch(MismatchedSize {
136 expected: self.stride * self.height,
137 received: data_len,
138 }));
139 }
140 if (self.stride) < (self.width * cn) {
141 return Err(ForgeError::MinimumStrideSizeMismatch(MismatchedSize {
142 expected: self.width * cn,
143 received: self.stride,
144 }));
145 }
146 Ok(())
147 }
148
149 #[inline]
151 pub fn row_stride(&self) -> usize {
152 if self.stride == 0 {
153 self.width * N
154 } else {
155 self.stride
156 }
157 }
158
159 #[inline]
160 pub fn check_layout(&self) -> Result<(), ForgeError> {
161 if self.width == 0 || self.height == 0 {
162 return Err(ForgeError::ZeroBaseSize);
163 }
164 let cn = N;
165 if self.data.len() < self.stride * (self.height - 1) + self.width * cn {
166 return Err(ForgeError::MinimumSliceSizeMismatch(MismatchedSize {
167 expected: self.stride * self.height,
168 received: self.data.len(),
169 }));
170 }
171 if (self.stride) < (self.width * cn) {
172 return Err(ForgeError::MinimumStrideSizeMismatch(MismatchedSize {
173 expected: self.width * cn,
174 received: self.stride,
175 }));
176 }
177 Ok(())
178 }
179}
180
181impl<'a, T: Clone + Copy + Default + Debug, const N: usize> GainImageMut<'a, T, N> {
182 pub fn alloc(width: usize, height: usize) -> Self {
184 Self {
185 data: BufferStore::Owned(vec![T::default(); width * height * N]),
186 width,
187 height,
188 stride: width * N,
189 }
190 }
191
192 pub fn borrow(arr: &'a mut [T], width: usize, height: usize) -> Self {
195 Self {
196 data: BufferStore::Borrowed(arr),
197 width,
198 height,
199 stride: width * N,
200 }
201 }
202
203 #[inline]
205 pub fn row_stride(&self) -> usize {
206 if self.stride == 0 {
207 self.width * N
208 } else {
209 self.stride
210 }
211 }
212
213 #[inline]
215 pub fn check_layout(&self) -> Result<(), ForgeError> {
216 if self.width == 0 || self.height == 0 {
217 return Err(ForgeError::ZeroBaseSize);
218 }
219 let data_len = self.data.borrow().len();
220 if data_len < self.stride * (self.height - 1) + self.width * N {
221 return Err(ForgeError::MinimumSliceSizeMismatch(MismatchedSize {
222 expected: self.stride * self.height,
223 received: data_len,
224 }));
225 }
226 if (self.stride) < (self.width * N) {
227 return Err(ForgeError::MinimumStrideSizeMismatch(MismatchedSize {
228 expected: self.width * N,
229 received: self.stride,
230 }));
231 }
232 Ok(())
233 }
234
235 #[inline]
237 pub fn size_matches(&self, other: &GainImage<'_, T, N>) -> Result<(), ForgeError> {
238 if self.width == other.width && self.height == other.height {
239 return Ok(());
240 }
241 Err(ForgeError::ImageSizeMismatch)
242 }
243
244 #[inline]
246 pub fn size_matches_mut(&self, other: &GainImageMut<'_, T, N>) -> Result<(), ForgeError> {
247 if self.width == other.width && self.height == other.height {
248 return Ok(());
249 }
250 Err(ForgeError::ImageSizeMismatch)
251 }
252
253 pub fn to_immutable_ref(&self) -> GainImage<'_, T, N> {
254 GainImage {
255 data: std::borrow::Cow::Borrowed(self.data.borrow()),
256 stride: self.row_stride(),
257 width: self.width,
258 height: self.height,
259 }
260 }
261}
262
263impl<const N: usize> GainImage<'_, u8, N> {
264 pub fn expand_to_u16(&self, target_bit_depth: usize) -> GainImageMut<'_, u16, N> {
265 assert!(target_bit_depth >= 8 || target_bit_depth <= 16);
266 let mut new_image = GainImageMut::<u16, N>::alloc(self.width, self.height);
267 let dst_stride = new_image.row_stride();
268 let shift_left = target_bit_depth.saturating_sub(8) as u16;
269 let shift_right = 8u16.saturating_sub(shift_left);
270 for (src, dst) in self
271 .data
272 .as_ref()
273 .chunks_exact(self.row_stride())
274 .zip(new_image.data.borrow_mut().chunks_exact_mut(dst_stride))
275 {
276 let src = &src[..self.width * N];
277 let dst = &mut dst[..self.width * N];
278 for (&src, dst) in src.iter().zip(dst.iter_mut()) {
279 *dst = ((src as u16) << shift_left) | ((src as u16) >> shift_right);
280 }
281 }
282 new_image
283 }
284}
285
286impl<const N: usize> GainImageMut<'_, u8, N> {
287 pub fn expand_to_u16(&self, target_bit_depth: usize) -> GainImageMut<'_, u16, N> {
288 assert!(target_bit_depth >= 8 || target_bit_depth <= 16);
289 let mut new_image = GainImageMut::<u16, N>::alloc(self.width, self.height);
290 let dst_stride = new_image.row_stride();
291 let shift_left = target_bit_depth.saturating_sub(8) as u16;
292 let shift_right = 8u16.saturating_sub(shift_left);
293 let width = self.width;
294 for (src, dst) in self
295 .data
296 .borrow()
297 .chunks_exact(self.row_stride())
298 .zip(new_image.data.borrow_mut().chunks_exact_mut(dst_stride))
299 {
300 let src = &src[..width * N];
301 let dst = &mut dst[..width * N];
302 for (&src, dst) in src.iter().zip(dst.iter_mut()) {
303 *dst = ((src as u16) << shift_left) | ((src as u16) >> shift_right);
304 }
305 }
306 new_image
307 }
308}