1use image::{ImageBuffer, Luma, Pixel};
23
24use std::fmt;
25
26pub trait PixelTransform<Pix: Pixel> {
32 type Output: Pixel + 'static;
34
35 fn transform_pixel(&self, pixel: Pix) -> Self::Output;
37}
38
39impl<Pix: Pixel + 'static> PixelTransform<Pix> for () {
41 type Output = Pix;
42
43 #[inline]
44 fn transform_pixel(&self, pixel: Pix) -> Self::Output {
45 pixel
46 }
47}
48
49impl<Pix, O> PixelTransform<Pix> for Box<dyn PixelTransform<Pix, Output = O>>
50where
51 Pix: Pixel,
52 O: Pixel + 'static,
53{
54 type Output = O;
55
56 #[inline]
57 fn transform_pixel(&self, pixel: Pix) -> Self::Output {
58 (**self).transform_pixel(pixel)
59 }
60}
61
62impl<F, G, Pix: Pixel> PixelTransform<Pix> for (F, G)
65where
66 F: PixelTransform<Pix>,
67 G: PixelTransform<F::Output>,
68{
69 type Output = G::Output;
70
71 #[inline]
72 fn transform_pixel(&self, pixel: Pix) -> Self::Output {
73 self.1.transform_pixel(self.0.transform_pixel(pixel))
74 }
75}
76
77#[derive(Debug, Clone, Copy, Default)]
79pub struct Negative;
80
81impl PixelTransform<Luma<u8>> for Negative {
82 type Output = Luma<u8>;
83
84 #[inline]
85 fn transform_pixel(&self, pixel: Luma<u8>) -> Self::Output {
86 Luma([u8::max_value() - pixel[0]])
87 }
88}
89
90#[derive(Debug, Clone, Copy, Default)]
93pub struct Smoothstep;
94
95impl PixelTransform<Luma<u8>> for Smoothstep {
96 type Output = Luma<u8>;
97
98 #[inline]
99 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
100 fn transform_pixel(&self, pixel: Luma<u8>) -> Self::Output {
101 let clamped_x = f32::from(pixel[0]) / 255.0;
102 let output = clamped_x * clamped_x * (3.0 - 2.0 * clamped_x);
103 Luma([(output * 255.0).round() as u8])
104 }
105}
106
/// A 256-entry color lookup table, indexed by 8-bit grayscale intensity.
#[derive(Clone)]
pub struct Palette<T> {
    pixels: [T; 256],
}

impl<T> fmt::Debug for Palette<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The table is formatted through a slice rather than as `[T; 256]`,
        // presumably so that no `Debug` impl for large arrays is required
        // (older compilers only provided it for arrays of up to 32 elements).
        let entries: &[T] = &self.pixels;
        formatter
            .debug_struct("Palette")
            .field("pixels", &entries)
            .finish()
    }
}
121
122impl<T> Palette<T>
123where
124 T: Pixel<Subpixel = u8> + 'static,
125{
126 #[allow(
134 clippy::cast_precision_loss,
135 clippy::cast_possible_truncation,
136 clippy::cast_sign_loss
137 )]
138 pub fn new(colors: &[T]) -> Self {
139 assert!(colors.len() >= 2, "palette must contain at least 2 colors");
140 assert!(
141 colors.len() <= 256,
142 "palette cannot contain more than 256 colors"
143 );
144 let len_scale = (colors.len() - 1) as f32;
145
146 let mut pixels = [T::from_channels(0, 0, 0, 0); 256];
147 for (i, pixel) in pixels.iter_mut().enumerate() {
148 let float_i = i as f32 / 255.0 * len_scale;
149
150 let mut prev_color_idx = float_i as usize; if prev_color_idx == colors.len() - 1 {
152 prev_color_idx -= 1;
153 }
154 debug_assert!(prev_color_idx + 1 < colors.len());
155
156 let prev_color = colors[prev_color_idx].channels();
157 let next_color = colors[prev_color_idx + 1].channels();
158 let blend_factor = float_i - prev_color_idx as f32;
159 debug_assert!((0.0..=1.0).contains(&blend_factor));
160
161 let mut blended_channels = [0_u8; 4];
162 let channel_count = T::CHANNEL_COUNT as usize;
163 for (ch, blended_channel) in blended_channels[..channel_count].iter_mut().enumerate() {
164 let blended = f32::from(prev_color[ch]) * (1.0 - blend_factor)
165 + f32::from(next_color[ch]) * blend_factor;
166 *blended_channel = blended.round() as u8;
167 }
168 *pixel = *T::from_slice(&blended_channels[..channel_count]);
169 }
170
171 Self { pixels }
172 }
173}
174
175impl<Pix: Pixel + 'static> PixelTransform<Luma<u8>> for Palette<Pix> {
176 type Output = Pix;
177
178 #[inline]
179 fn transform_pixel(&self, pixel: Luma<u8>) -> Self::Output {
180 self.pixels[pixel[0] as usize]
181 }
182}
183
184pub trait ApplyTransform<Pix: Pixel, F> {
189 type CombinedTransform: PixelTransform<Pix>;
191 fn apply(self, transform: F) -> ImageAndTransform<Pix, Self::CombinedTransform>;
193}
194
195#[derive(Debug)]
198pub struct ImageAndTransform<Pix, F>
199where
200 Pix: Pixel,
201{
202 source_image: ImageBuffer<Pix, Vec<Pix::Subpixel>>,
203 transform: F,
204}
205
206impl<Pix, F> ImageAndTransform<Pix, F>
207where
208 Pix: Pixel + Copy + 'static,
209 F: PixelTransform<Pix>,
210 <F::Output as Pixel>::Subpixel: 'static,
211{
212 pub fn transform(&self) -> ImageBuffer<F::Output, Vec<<F::Output as Pixel>::Subpixel>> {
214 let mut output = ImageBuffer::new(self.source_image.width(), self.source_image.height());
215
216 let output_iter = self
217 .source_image
218 .enumerate_pixels()
219 .map(|(x, y, pixel)| (x, y, self.transform.transform_pixel(*pixel)));
220 for (x, y, out_pixel) in output_iter {
221 output[(x, y)] = out_pixel;
222 }
223 output
224 }
225}
226
227impl<Pix, F> ApplyTransform<Pix, F> for ImageBuffer<Pix, Vec<Pix::Subpixel>>
228where
229 Pix: Pixel,
230 F: PixelTransform<Pix>,
231{
232 type CombinedTransform = F;
233
234 fn apply(self, transform: F) -> ImageAndTransform<Pix, F> {
235 ImageAndTransform {
236 source_image: self,
237 transform,
238 }
239 }
240}
241
242impl<Pix, F, G> ApplyTransform<Pix, G> for ImageAndTransform<Pix, F>
243where
244 Pix: Pixel,
245 F: PixelTransform<Pix>,
246 G: PixelTransform<F::Output>,
247{
248 type CombinedTransform = (F, G);
249
250 fn apply(self, transform: G) -> ImageAndTransform<Pix, (F, G)> {
251 ImageAndTransform {
252 source_image: self.source_image,
253 transform: (self.transform, transform),
254 }
255 }
256}
257
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::cast_sign_loss
)]
mod tests {
    use super::*;
    use image::{GrayImage, Rgb};

    /// A 16x16 image in which every `u8` intensity appears exactly once.
    fn gradient() -> GrayImage {
        GrayImage::from_fn(16, 16, |x, y| Luma::from([(y * 16 + x) as u8]))
    }

    #[test]
    fn simple_transform() {
        let image = GrayImage::from_fn(100, 100, |x, y| Luma::from([(x + y) as u8]));
        let image = image.apply(Negative).apply(Smoothstep).transform();
        for (x, y, pix) in image.enumerate_pixels() {
            let negated = (255 - x - y) as f32 / 255.0;
            let smoothed = negated * negated * (3.0 - 2.0 * negated);
            let expected_pixel = (smoothed * 255.0).round() as u8;
            assert_eq!(pix[0], expected_pixel);
        }
    }

    #[test]
    fn identity_transform_preserves_image() {
        let image = gradient();
        let transformed = image.clone().apply(()).transform();
        assert_eq!(transformed.as_raw(), image.as_raw());
    }

    #[test]
    fn boxed_transform_matches_concrete_transform() {
        let image = gradient();
        let boxed: Box<dyn PixelTransform<Luma<u8>, Output = Luma<u8>>> = Box::new(Negative);
        let via_box = image.clone().apply(boxed).transform();
        let direct = image.apply(Negative).transform();
        assert_eq!(via_box.as_raw(), direct.as_raw());
    }

    #[test]
    fn palette_basics() {
        let palette = Palette::new(&[Rgb([0, 255, 0]), Rgb([255, 255, 255])]);
        for (i, &pixel) in palette.pixels.iter().enumerate() {
            assert_eq!(pixel, Rgb([i as u8, 255, i as u8]));
        }
    }

    #[test]
    fn palette_endpoints_match_first_and_last_colors() {
        let colors = [Rgb([10, 20, 30]), Rgb([40, 50, 60]), Rgb([70, 80, 90])];
        let palette = Palette::new(&colors);
        assert_eq!(palette.pixels[0], colors[0]);
        assert_eq!(palette.pixels[255], colors[2]);
    }

    #[test]
    fn palette_maps_gray_image_to_colors() {
        let palette = Palette::new(&[Rgb([0, 0, 0]), Rgb([255, 0, 0])]);
        let image = GrayImage::from_fn(256, 1, |x, _| Luma::from([x as u8]));
        let colored = image.apply(palette).transform();
        for (x, _, pix) in colored.enumerate_pixels() {
            assert_eq!(*pix, Rgb([x as u8, 0, 0]));
        }
    }

    #[test]
    #[should_panic(expected = "palette must contain at least 2 colors")]
    fn palette_rejects_single_color() {
        let _ = Palette::new(&[Rgb([0_u8, 0, 0])]);
    }

    #[test]
    #[should_panic(expected = "palette cannot contain more than 256 colors")]
    fn palette_rejects_too_many_colors() {
        let colors = [Luma([0_u8]); 257];
        let _ = Palette::new(&colors);
    }
}