// lutgen_studio/worker.rs

1use std::fmt::Display;
2use std::hash::{DefaultHasher, Hash, Hasher};
3use std::path::{Path, PathBuf};
4use std::sync::atomic::AtomicBool;
5use std::sync::mpsc::channel;
6use std::sync::Arc;
7
8use log::{debug, info};
9use lutgen::GenerateLut;
10use web_time::{Duration, Instant};
11
12use crate::color::Color;
13use crate::state::{
14    BlurArgs,
15    Common,
16    CommonRbf,
17    GaussianRbfArgs,
18    GaussianSamplingArgs,
19    ShepardsMethodArgs,
20};
21use crate::updates::UpdateInfo;
22
23#[derive(serde::Serialize, serde::Deserialize)]
24pub enum FrontendEvent {
25    LoadFile(PathBuf, #[cfg(target_arch = "wasm32")] Vec<u8>),
26    Apply(Vec<[u8; 3]>, Common, LutAlgorithmArgs, Arc<AtomicBool>),
27    SaveAs(
28        #[cfg(not(target_arch = "wasm32"))] PathBuf,
29        #[cfg(target_arch = "wasm32")] image::ImageFormat,
30    ),
31}
32
33#[derive(serde::Serialize, serde::Deserialize, Hash, Debug)]
34pub enum LutAlgorithmArgs {
35    GaussianRbf {
36        rbf: CommonRbf,
37        args: GaussianRbfArgs,
38    },
39    ShepardsMethod {
40        rbf: CommonRbf,
41        args: ShepardsMethodArgs,
42    },
43    GaussianSampling {
44        args: GaussianSamplingArgs,
45    },
46    GaussianBlur {
47        args: BlurArgs,
48    },
49    NearestNeighbor,
50}
51
52#[derive(serde::Serialize, serde::Deserialize)]
53pub enum BackendEvent {
54    Error(String),
55    Update(UpdateInfo),
56    SetImage {
57        time: Duration,
58        source: ImageSource,
59        image: Arc<[u8]>,
60        dim: (u32, u32),
61    },
62    #[cfg(target_arch = "wasm32")]
63    SaveData(Duration, String, image::ImageFormat),
64}
65
66impl Display for BackendEvent {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        match self {
69            BackendEvent::Error(e) => format!("Error: {e}").fmt(f),
70            BackendEvent::Update(_) => Ok(()),
71            BackendEvent::SetImage {
72                time,
73                dim: (x, y),
74                source: path,
75                ..
76            } => match path {
77                ImageSource::Image(_) => format!("Opened {x}x{y} image in {time:.2?}").fmt(f),
78                ImageSource::Edited(_) => {
79                    format!("Generated and applied LUT to image in {time:.2?}").fmt(f)
80                },
81            },
82            #[cfg(target_arch = "wasm32")]
83            BackendEvent::SaveData(time, _, format) => {
84                format!("Encoded {format:?} for download in {time:.2?}").fmt(f)
85            },
86        }
87    }
88}
89
90#[derive(serde::Serialize, serde::Deserialize)]
91pub enum ImageSource {
92    Image(PathBuf),
93    Edited(u64),
94}
95
96impl Display for ImageSource {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        match self {
99            ImageSource::Image(path_buf) => Display::fmt(&path_buf.display(), f),
100            ImageSource::Edited(hash) => Display::fmt(hash, f),
101        }
102    }
103}
104
105pub struct WorkerHandle {
106    #[cfg(target_arch = "wasm32")]
107    bridge: gloo_worker::WorkerBridge<Worker>,
108
109    #[cfg(not(target_arch = "wasm32"))]
110    tx: std::sync::mpsc::Sender<FrontendEvent>,
111
112    rx: std::sync::mpsc::Receiver<BackendEvent>,
113    abort: Arc<AtomicBool>,
114}
115
116impl WorkerHandle {
117    #[cfg(not(target_arch = "wasm32"))]
118    pub fn spawn(ctx: egui::Context) -> Self {
119        let (tx, worker_rx) = channel();
120        let (worker_tx, rx) = channel();
121        let abort = Arc::new(AtomicBool::new(false));
122
123        // Spawn thread to fetch the latest version and send it to the frontend if newer
124        let worker_tx_cloned = worker_tx.clone();
125        std::thread::spawn(move || {
126            if let Ok(Some(update)) = crate::updates::check_for_updates() {
127                worker_tx_cloned
128                    .send(BackendEvent::Update(update))
129                    .expect("failed to send update info to frontend");
130            }
131        });
132
133        std::thread::spawn(move || {
134            let mut worker = Worker {
135                current_image: None,
136                hasher: DefaultHasher::new(),
137                last_render: Default::default(),
138            };
139            while let Ok(event) = worker_rx.recv() {
140                if let Some(event) = worker.handle_event(event) {
141                    worker_tx
142                        .send(event)
143                        .expect("failed to send backend event to ui thread");
144                }
145                ctx.request_repaint();
146            }
147        });
148
149        WorkerHandle { tx, rx, abort }
150    }
151
152    #[cfg(target_arch = "wasm32")]
153    pub fn spawn(ctx: egui::Context) -> Self {
154        use gloo_worker::Spawnable;
155
156        let abort = Arc::new(AtomicBool::new(false));
157        let (tx, rx) = channel();
158        let bridge = Worker::spawner()
159            .callback(move |event| {
160                tx.send(event)
161                    .expect("failed to send backend event to worker handle");
162                ctx.request_repaint();
163            })
164            .spawn("worker.js");
165
166        Self { rx, bridge, abort }
167    }
168
169    fn send(&self, event: FrontendEvent) {
170        #[cfg(not(target_arch = "wasm32"))]
171        self.tx
172            .send(event)
173            .expect("failed to send save as request to worker");
174        #[cfg(target_arch = "wasm32")]
175        self.bridge.send(event);
176    }
177
178    pub fn save_as(
179        &self,
180        #[cfg(not(target_arch = "wasm32"))] item: PathBuf,
181        #[cfg(target_arch = "wasm32")] item: image::ImageFormat,
182    ) {
183        self.send(FrontendEvent::SaveAs(item));
184    }
185
186    pub fn load_file(&self, path: PathBuf, #[cfg(target_arch = "wasm32")] bytes: Vec<u8>) {
187        #[cfg(not(target_arch = "wasm32"))]
188        self.send(FrontendEvent::LoadFile(path));
189        #[cfg(target_arch = "wasm32")]
190        self.send(FrontendEvent::LoadFile(path, bytes));
191    }
192
193    pub fn apply_palette(&mut self, palette: Vec<[u8; 3]>, common: Common, args: LutAlgorithmArgs) {
194        // cancel previous run and init a new abort signal
195        self.abort.store(true, std::sync::atomic::Ordering::Relaxed);
196        self.abort = Arc::new(AtomicBool::new(false));
197
198        self.send(FrontendEvent::Apply(
199            palette,
200            common,
201            args,
202            self.abort.clone(),
203        ))
204    }
205
206    pub fn poll_event(&self) -> Option<BackendEvent> {
207        self.rx.try_recv().ok()
208    }
209}
210
211pub struct Worker {
212    current_image: Option<lutgen::RgbaImage>,
213    hasher: DefaultHasher,
214    last_render: Arc<[u8]>,
215}
216
217impl Worker {
218    fn handle_event(&mut self, event: FrontendEvent) -> Option<BackendEvent> {
219        let res = match event {
220            #[cfg(not(target_arch = "wasm32"))]
221            FrontendEvent::SaveAs(path) => self.save_as(path),
222            #[cfg(target_arch = "wasm32")]
223            FrontendEvent::SaveAs(format) => self.save_as(format),
224            #[cfg(not(target_arch = "wasm32"))]
225            FrontendEvent::LoadFile(path) => self.load_file(&path),
226            #[cfg(target_arch = "wasm32")]
227            FrontendEvent::LoadFile(path, bytes) => self.load_file(&path, bytes),
228            FrontendEvent::Apply(palette, common, args, abort) => {
229                self.apply_palette(palette, common, args, abort)
230            },
231        };
232        match res {
233            Ok(event) => event,
234            Err(e) => Some(BackendEvent::Error(e)),
235        }
236    }
237
238    fn save_as(
239        &self,
240        #[cfg(not(target_arch = "wasm32"))] path: PathBuf,
241        #[cfg(target_arch = "wasm32")] format: image::ImageFormat,
242    ) -> Result<Option<BackendEvent>, String> {
243        if self.last_render.is_empty() {
244            return Err("Image must be applied at least once".into());
245        }
246        if let Some(image) = &self.current_image {
247            #[cfg(not(target_arch = "wasm32"))]
248            if image::save_buffer(
249                &path,
250                &self.last_render,
251                image.width(),
252                image.height(),
253                image::ColorType::Rgba8,
254            )
255            .is_err()
256            {
257                // image format likely doesn't support transparency, convert to RGB
258                let buffer: Vec<u8> = self
259                    .last_render
260                    .chunks_exact(4)
261                    .flat_map(|v| &v[0..3])
262                    .cloned()
263                    .collect();
264                image::save_buffer(
265                    path,
266                    &buffer,
267                    image.width(),
268                    image.height(),
269                    image::ColorType::Rgb8,
270                )
271                .map_err(|e| format!("failed to encode image: {e}"))?;
272            }
273
274            #[cfg(target_arch = "wasm32")]
275            {
276                use base64::Engine;
277
278                let time = Instant::now();
279                let width = image.width();
280                let height = image.height();
281
282                info!("Encoding {width}x{height} image as {format:?}");
283
284                // encode image in the given format
285                let mut buf = std::io::Cursor::new(Vec::new());
286                if let Err(_) = image::write_buffer_with_format(
287                    &mut buf,
288                    &self.last_render,
289                    width,
290                    height,
291                    image::ColorType::Rgba8,
292                    format,
293                ) {
294                    // image format likely doesn't support transparency
295                    let buffer: Vec<u8> = self
296                        .last_render
297                        .chunks_exact(4)
298                        .flat_map(|v| &v[0..3])
299                        .cloned()
300                        .collect();
301                    buf = std::io::Cursor::new(Vec::new());
302                    image::write_buffer_with_format(
303                        &mut buf,
304                        &buffer,
305                        width,
306                        height,
307                        image::ColorType::Rgb8,
308                        format,
309                    )
310                    .map_err(|e| format!("failed to encode image: {e}"))?;
311                }
312
313                // encode image file as data url and send to frontend
314                let data = base64::engine::general_purpose::STANDARD.encode(&buf.into_inner());
315                return Ok(Some(BackendEvent::SaveData(
316                    time.elapsed(),
317                    format!("data:{};base64,{data}", format.to_mime_type()),
318                    format,
319                )));
320            }
321        }
322
323        Ok(None)
324    }
325
326    fn load_file(
327        &mut self,
328        path: &Path,
329        #[cfg(target_arch = "wasm32")] bytes: Vec<u8>,
330    ) -> Result<Option<BackendEvent>, String> {
331        let time = Instant::now();
332
333        #[cfg(not(target_arch = "wasm32"))]
334        let image = image::open(path);
335        #[cfg(target_arch = "wasm32")]
336        let image = image::load_from_memory(&bytes);
337        let image = image.map_err(|e| e.to_string())?.to_rgba8();
338
339        // hash image
340        self.hasher = DefaultHasher::new();
341        image.hash(&mut self.hasher);
342
343        let frame = image.to_vec().into();
344        let dim = (image.height(), image.width());
345        self.current_image = Some(image);
346
347        Ok(Some(BackendEvent::SetImage {
348            time: time.elapsed(),
349            source: ImageSource::Image(path.to_path_buf()),
350            image: frame,
351            dim,
352        }))
353    }
354
355    /// Apply a palette to the currently loaded image
356    fn apply_palette(
357        &mut self,
358        palette: Vec<[u8; 3]>,
359        common: Common,
360        args: LutAlgorithmArgs,
361        abort: Arc<AtomicBool>,
362    ) -> Result<Option<BackendEvent>, String> {
363        let time = Instant::now();
364
365        let Some(mut image) = self.current_image.clone() else {
366            // do nothing if no image is loaded
367            return Ok(None);
368        };
369
370        // hash arguments with existing image hash
371        let mut hasher = self.hasher.clone();
372        palette.hash(&mut hasher);
373        common.hash(&mut hasher);
374        args.hash(&mut hasher);
375        let hash = hasher.finish();
376
377        info!("Generating LUT with args:\n{common:?}\n{args:?}");
378        debug!(
379            "LUT input palette ({} colors):\n{}",
380            palette.len(),
381            palette
382                .chunks(5)
383                .map(|v| v
384                    .iter()
385                    .cloned()
386                    .map(|v| Color(v).to_string())
387                    .collect::<Vec<_>>()
388                    .join(", "))
389                .collect::<Vec<_>>()
390                .join("\n")
391        );
392
393        // generate lut from arguments
394        let lut = match args {
395            LutAlgorithmArgs::GaussianRbf { rbf, args } => {
396                lutgen::interpolation::GaussianRemapper::new(
397                    &palette,
398                    *args.shape,
399                    rbf.nearest,
400                    *common.lum_factor,
401                    common.preserve,
402                )
403                .par_generate_lut_with_interrupt(common.level, abort)
404            },
405            LutAlgorithmArgs::ShepardsMethod { rbf, args } => {
406                lutgen::interpolation::ShepardRemapper::new(
407                    &palette,
408                    *args.power,
409                    rbf.nearest,
410                    *common.lum_factor,
411                    common.preserve,
412                )
413                .par_generate_lut_with_interrupt(common.level, abort)
414            },
415            LutAlgorithmArgs::GaussianSampling { args } => {
416                lutgen::interpolation::GaussianSamplingRemapper::new(
417                    &palette,
418                    *args.mean,
419                    *args.std_dev,
420                    args.iterations,
421                    *common.lum_factor,
422                    args.seed,
423                    common.preserve,
424                )
425                .par_generate_lut_with_interrupt(common.level, abort)
426            },
427            LutAlgorithmArgs::GaussianBlur { args } => {
428                lutgen::interpolation::GaussianBlurRemapper::new(
429                    &palette,
430                    *args.radius,
431                    *common.lum_factor,
432                    common.preserve,
433                )
434                .par_generate_lut_with_interrupt(common.level, abort)
435            },
436            LutAlgorithmArgs::NearestNeighbor => {
437                lutgen::interpolation::NearestNeighborRemapper::new(
438                    &palette,
439                    *common.lum_factor,
440                    common.preserve,
441                )
442                .par_generate_lut_with_interrupt(common.level, abort)
443            },
444        }
445        .ok_or("Cancelled generating hald clut".to_string())?;
446
447        // remap image
448        lutgen::identity::correct_image_with_level(&mut image, &lut, common.level);
449        self.last_render = image.to_vec().into();
450
451        Ok(Some(BackendEvent::SetImage {
452            time: time.elapsed(),
453            source: ImageSource::Edited(hash),
454            image: self.last_render.clone(),
455            dim: (image.height(), image.width()),
456        }))
457    }
458}
459
/// Web-worker entry point: each incoming [`FrontendEvent`] is handled in turn and any
/// reply is sent back to the requesting bridge.
///
/// NOTE(review): `FrontendEvent::Apply` carries an `Arc<AtomicBool>` that must be serialized
/// to reach this worker, so the flag seen here is presumably a detached copy. Setting the
/// UI-side flag would then not interrupt a run in progress — confirm.
#[cfg(target_arch = "wasm32")]
impl gloo_worker::Worker for Worker {
    type Input = FrontendEvent;
    type Message = ();
    type Output = BackendEvent;

    fn create(_scope: &gloo_worker::WorkerScope<Self>) -> Self {
        // Nothing is loaded or rendered yet.
        Self {
            current_image: None,
            hasher: DefaultHasher::new(),
            last_render: Default::default(),
        }
    }

    fn received(
        &mut self,
        scope: &gloo_worker::WorkerScope<Self>,
        msg: FrontendEvent,
        id: gloo_worker::HandlerId,
    ) {
        let Some(reply) = self.handle_event(msg) else {
            return;
        };
        scope.respond(id, reply);
    }

    fn update(&mut self, _scope: &gloo_worker::WorkerScope<Self>, _msg: Self::Message) {}
}