kmeans_color_gpu/
lib.rs

1use anyhow::{anyhow, Result};
2use palette::{IntoColor, Lab, Srgba};
3pub use rgb::RGBA8;
4use std::sync::Arc;
5use std::{fmt::Display, str::FromStr};
6use wgpu::{
7    Device, DeviceDescriptor, Features, Instance, PowerPreference, Queue, RequestAdapterOptionsBase,
8};
9
10use crate::image::{Container, Image};
11use crate::structures::{CentroidsBuffer, InputTexture};
12
13mod future;
14mod modules;
15mod octree;
16mod operations;
17#[cfg(test)]
18mod shader_tests;
19mod structures;
20mod utils;
21
22pub mod image;
23
24pub struct ImageProcessor {
25    device: Arc<Device>,
26    queue: Arc<Queue>,
27}
28
29impl ImageProcessor {
30    /// Create a new ImageProcessor, initializing a [wgpu::Device] and [wgpu::Queue]
31    /// to use in future operations.
32    /// ```rust,no_run
33    /// use pollster::FutureExt;
34    /// use kmeans_color_gpu::ImageProcessor;
35    ///
36    /// let image_processor = ImageProcessor::new().block_on();
37    /// ```
38    pub async fn new() -> Result<Self> {
39        let instance = Instance::default();
40        let adapter = instance
41            .request_adapter(&RequestAdapterOptionsBase {
42                power_preference: PowerPreference::HighPerformance,
43                force_fallback_adapter: false,
44                compatible_surface: None,
45            })
46            .await
47            .ok_or_else(|| anyhow::anyhow!("Couldn't create the adapter"))?;
48
49        let features = adapter.features();
50        let (device, queue) = adapter
51            .request_device(
52                &DeviceDescriptor {
53                    label: None,
54                    features: features & (Features::TIMESTAMP_QUERY),
55                    limits: Default::default(),
56                },
57                None,
58            )
59            .await?;
60
61        Ok(Self {
62            device: Arc::new(device),
63            queue: Arc::new(queue),
64        })
65    }
66
67    pub async fn palette<C: Container>(
68        &self,
69        color_count: u32,
70        image: &Image<C>,
71        algo: Algorithm,
72    ) -> Result<Vec<RGBA8>> {
73        match algo {
74            Algorithm::Kmeans => kmeans_palette(self, color_count, image).await,
75            Algorithm::Octree => octree_palette(self, color_count, image).await,
76        }
77    }
78
79    pub async fn find<C: Container>(
80        &self,
81        image: &Image<C>,
82        colors: &[RGBA8],
83        reduce_mode: &ReduceMode,
84    ) -> Result<Image<Vec<RGBA8>>> {
85        let input_texture = InputTexture::new(&self.device, &self.queue, image);
86        let centroids_buffer =
87            CentroidsBuffer::fixed_centroids(colors, &ColorSpace::Lab, &self.device);
88
89        match reduce_mode {
90            ReduceMode::Replace => operations::find_colors(
91                &self.device,
92                &self.queue,
93                &input_texture,
94                &ColorSpace::Lab,
95                &centroids_buffer,
96            ),
97            ReduceMode::Dither => operations::dither_colors(
98                &self.device,
99                &self.queue,
100                &input_texture,
101                &ColorSpace::Lab,
102                &centroids_buffer,
103            ),
104            ReduceMode::Meld => operations::meld_colors(
105                &self.device,
106                &self.queue,
107                &input_texture,
108                &ColorSpace::Lab,
109                &centroids_buffer,
110            ),
111        }?
112        .pull_image(&self.device, &self.queue)
113        .await
114    }
115
116    pub async fn reduce<C: Container>(
117        &self,
118        color_count: u32,
119        image: &Image<C>,
120        algo: &Algorithm,
121        reduce_mode: &ReduceMode,
122    ) -> Result<Image<Vec<RGBA8>>> {
123        let input_texture = InputTexture::new(&self.device, &self.queue, image);
124
125        let centroids_buffer = match algo {
126            Algorithm::Kmeans => operations::extract_palette_kmeans(
127                &self.device,
128                &self.queue,
129                &input_texture,
130                &ColorSpace::Lab,
131                color_count,
132            )?,
133            Algorithm::Octree => {
134                let palette = octree_palette(self, color_count, image).await?;
135                CentroidsBuffer::fixed_centroids(&palette, &ColorSpace::Lab, &self.device)
136            }
137        };
138
139        let output_texture = match reduce_mode {
140            ReduceMode::Replace => operations::find_colors(
141                &self.device,
142                &self.queue,
143                &input_texture,
144                &ColorSpace::Lab,
145                &centroids_buffer,
146            ),
147            ReduceMode::Dither => operations::dither_colors(
148                &self.device,
149                &self.queue,
150                &input_texture,
151                &ColorSpace::Lab,
152                &centroids_buffer,
153            ),
154            ReduceMode::Meld => operations::meld_colors(
155                &self.device,
156                &self.queue,
157                &input_texture,
158                &ColorSpace::Lab,
159                &centroids_buffer,
160            ),
161        }?;
162
163        output_texture.pull_image(&self.device, &self.queue).await
164    }
165}
166
/// Color space passed to the palette-extraction and color-matching operations.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ColorSpace {
    /// CIE L*a*b*.
    Lab,
    /// RGB.
    Rgb,
}
172
173impl ColorSpace {
174    pub fn from(str: &str) -> Option<ColorSpace> {
175        match str {
176            "lab" => Some(ColorSpace::Lab),
177            "rgb" => Some(ColorSpace::Rgb),
178            _ => None,
179        }
180    }
181
182    pub fn name(&self) -> &'static str {
183        match self {
184            ColorSpace::Lab => "lab",
185            ColorSpace::Rgb => "rgb",
186        }
187    }
188
189    pub fn convergence(&self) -> f32 {
190        match self {
191            ColorSpace::Lab => 1.0,
192            ColorSpace::Rgb => 0.01,
193        }
194    }
195}
196
197impl FromStr for ColorSpace {
198    type Err = anyhow::Error;
199
200    fn from_str(s: &str) -> Result<Self, Self::Err> {
201        match s {
202            "lab" => Ok(ColorSpace::Lab),
203            "rgb" => Ok(ColorSpace::Rgb),
204            _ => Err(anyhow!("Unsupported color space {s}")),
205        }
206    }
207}
208
209impl Display for ColorSpace {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        write!(f, "{}", self.name())
212    }
213}
214
/// Palette-extraction algorithm selectable in [`ImageProcessor::palette`] and
/// [`ImageProcessor::reduce`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Algorithm {
    /// K-means clustering, run on the GPU.
    Kmeans,
    /// Octree quantization over the image pixels.
    ///
    /// Large images are downscaled first.
    Octree,
}

impl Display for Algorithm {
    /// Writes the lowercase algorithm name (`"kmeans"` or `"octree"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Algorithm::Kmeans => "kmeans",
            Algorithm::Octree => "octree",
        };
        f.write_str(name)
    }
}
233
/// How pixels are mapped onto a palette in [`ImageProcessor::find`] and
/// [`ImageProcessor::reduce`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReduceMode {
    /// Replace pixels with palette colors (via `operations::find_colors`).
    Replace,
    /// Dither with the palette (via `operations::dither_colors`).
    Dither,
    /// Meld with the palette (via `operations::meld_colors`).
    Meld,
}

impl Display for ReduceMode {
    /// Writes the lowercase mode name (`"replace"`, `"dither"` or `"meld"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ReduceMode::Replace => "replace",
            ReduceMode::Dither => "dither",
            ReduceMode::Meld => "meld",
        };
        f.write_str(name)
    }
}
254
255async fn kmeans_palette<C: Container>(
256    image_processor: &ImageProcessor,
257    color_count: u32,
258    image: &Image<C>,
259) -> Result<Vec<RGBA8>> {
260    let input_texture = InputTexture::new(&image_processor.device, &image_processor.queue, image);
261
262    let mut colors = operations::extract_palette_kmeans(
263        &image_processor.device,
264        &image_processor.queue,
265        &input_texture,
266        &ColorSpace::Lab,
267        color_count,
268    )?
269    .pull_values(
270        &image_processor.device,
271        &image_processor.queue,
272        &ColorSpace::Lab,
273    )
274    .await?;
275
276    colors.sort_unstable_by(|a, b| {
277        let a: Lab = Srgba::new(a.r, a.g, a.b, a.a)
278            .into_format::<_, f32>()
279            .into_color();
280        let b: Lab = Srgba::new(b.r, b.g, b.b, b.a)
281            .into_format::<_, f32>()
282            .into_color();
283        a.l.partial_cmp(&b.l).unwrap()
284    });
285    Ok(colors)
286}
287
288async fn octree_palette<C: Container>(
289    image_processor: &ImageProcessor,
290    color_count: u32,
291    image: &Image<C>,
292) -> Result<Vec<RGBA8>> {
293    const MAX_SIZE: u32 = 128;
294
295    let (width, height) = image.dimensions;
296    let resized = if width > MAX_SIZE || height > MAX_SIZE {
297        let input_texture = InputTexture::new(
298            &image_processor.device,
299            &image_processor.queue,
300            image,
301        )
302        .resized(MAX_SIZE, &image_processor.device, &image_processor.queue);
303        Some(
304            input_texture
305                .pull_image(&image_processor.device, &image_processor.queue)
306                .await?,
307        )
308    } else {
309        None
310    };
311
312    let pixels: &[RGBA8] = if let Some(resized) = &resized {
313        &resized.rgba
314    } else {
315        &image.rgba
316    };
317
318    let mut colors = operations::extract_palette_octree(pixels, color_count)?;
319
320    colors.sort_unstable_by(|a, b| {
321        let a: Lab = Srgba::new(a.r, a.g, a.b, a.a)
322            .into_format::<_, f32>()
323            .into_color();
324        let b: Lab = Srgba::new(b.r, b.g, b.b, b.a)
325            .into_format::<_, f32>()
326            .into_color();
327        a.l.partial_cmp(&b.l).unwrap()
328    });
329
330    Ok(colors)
331}