use crate::{ChromaHash, Gamut};
use std::num::NonZeroUsize;
use std::sync::mpsc::{self, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
/// One image to hash: tightly packed 8-bit RGBA pixels plus the gamut they are in.
///
/// Cloning is cheap for the pixel data: the buffer is shared through an [`Arc`],
/// so handing an item to a worker thread does not copy pixels.
#[derive(Debug, Clone)]
pub struct ImageInput {
    /// Width in pixels. `BatchEncoder::encode_batch` rejects a width of 0.
    pub w: u32,
    /// Height in pixels. `BatchEncoder::encode_batch` rejects a height of 0.
    pub h: u32,
    /// Pixel data, 4 bytes (R, G, B, A) per pixel, so `rgba.len()` must equal
    /// `w * h * 4`. Presumably row-major (that is how the tests lay it out) —
    /// confirm against `ChromaHash::encode`.
    pub rgba: Arc<[u8]>,
    /// Gamut of the pixel values, passed unchanged to `ChromaHash::encode`.
    pub gamut: Gamut,
}
/// One unit of pool work: encode `input`, then report the hash tagged with
/// `index` on `result_tx`.
struct Job {
    /// Per-batch channel on which the worker sends back `(index, hash)`.
    result_tx: Sender<(usize, ChromaHash)>,
    /// Position of `input` within the caller's batch; used to restore input order.
    index: usize,
    /// The image to encode.
    input: ImageInput,
}
/// A fixed pool of worker threads that computes [`ChromaHash`]es for batches of
/// images in parallel.
///
/// Workers are spawned once by [`BatchEncoder::new`] / [`BatchEncoder::with_threads`]
/// and reused by every call to [`BatchEncoder::encode_batch`]. Dropping the encoder
/// closes the job channel and joins every worker.
pub struct BatchEncoder {
    /// Sends jobs to the workers. Set to `None` in `Drop`, which closes the
    /// channel and makes each worker leave its receive loop.
    job_tx: Option<Sender<Job>>,
    /// Handles of the spawned workers, joined on drop.
    workers: Vec<JoinHandle<()>>,
}
impl BatchEncoder {
#[must_use]
pub fn new() -> Self {
let threads = thread::available_parallelism()
.map(NonZeroUsize::get)
.unwrap_or(1);
Self::with_threads(threads)
}
#[must_use]
pub fn with_threads(n: usize) -> Self {
let n = n.max(1);
let (job_tx, job_rx) = mpsc::channel::<Job>();
let job_rx = Arc::new(Mutex::new(job_rx));
let mut workers = Vec::with_capacity(n);
for _ in 0..n {
let job_rx = Arc::clone(&job_rx);
workers.push(thread::spawn(move || {
loop {
let message = job_rx.lock().unwrap().recv();
match message {
Ok(job) => {
let hash = ChromaHash::encode(
job.input.w,
job.input.h,
&job.input.rgba,
job.input.gamut,
);
let _ = job.result_tx.send((job.index, hash));
}
Err(_) => break,
}
}
}));
}
Self {
job_tx: Some(job_tx),
workers,
}
}
#[must_use]
pub fn encode_batch(&self, items: &[ImageInput]) -> Vec<ChromaHash> {
for (i, item) in items.iter().enumerate() {
assert!(item.w >= 1, "item {i}: width must be >= 1");
assert!(item.h >= 1, "item {i}: height must be >= 1");
assert!(
item.rgba.len() == (item.w as usize) * (item.h as usize) * 4,
"item {i}: rgba length mismatch"
);
}
if items.is_empty() {
return Vec::new();
}
let job_tx = self
.job_tx
.as_ref()
.expect("BatchEncoder job channel is open while alive");
let (result_tx, result_rx) = mpsc::channel::<(usize, ChromaHash)>();
for (index, item) in items.iter().enumerate() {
job_tx
.send(Job {
index,
input: item.clone(),
result_tx: result_tx.clone(),
})
.expect("worker pool is running");
}
drop(result_tx);
let mut out = vec![ChromaHash::from_bytes([0u8; 32]); items.len()];
for _ in 0..items.len() {
let (index, hash) = result_rx.recv().expect("every job reports a result");
out[index] = hash;
}
out
}
}
impl Default for BatchEncoder {
fn default() -> Self {
Self::new()
}
}
impl Drop for BatchEncoder {
    /// Shuts the pool down. Dropping the job sender closes the channel, so each
    /// worker's `recv` fails and it leaves its loop. Each worker is then joined;
    /// a worker that panicked is ignored.
    fn drop(&mut self) {
        drop(self.job_tx.take());
        self.workers.drain(..).for_each(|handle| {
            let _ = handle.join();
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w`×`h` RGBA buffer in which every pixel is `(r, g, b, a)`.
    fn solid_image(w: u32, h: u32, r: u8, g: u8, b: u8, a: u8) -> Arc<[u8]> {
        let pixels = (w * h) as usize;
        Arc::from([r, g, b, a].repeat(pixels))
    }

    /// Builds a `w`×`h` row-major RGBA buffer whose red channel ramps up and
    /// green channel ramps down from left to right (blue 128, fully opaque).
    fn horizontal_gradient(w: u32, h: u32) -> Arc<[u8]> {
        let span = w.saturating_sub(1).max(1) as f64;
        let row: Vec<u8> = (0..w)
            .flat_map(|x| {
                let t = f64::from(x) / span;
                [(t * 255.0) as u8, ((1.0 - t) * 255.0) as u8, 128, 255]
            })
            .collect();
        Arc::from(row.repeat(h as usize))
    }

    /// Five items across five gamuts, mixing solid and gradient content, square
    /// and non-square shapes, a translucent image, and a 1×1 image.
    fn mixed_items() -> Vec<ImageInput> {
        let item = |w, h, rgba, gamut| ImageInput { w, h, rgba, gamut };
        vec![
            item(4, 4, solid_image(4, 4, 200, 100, 50, 255), Gamut::Srgb),
            item(8, 4, horizontal_gradient(8, 4), Gamut::DisplayP3),
            item(4, 8, solid_image(4, 8, 30, 200, 120, 128), Gamut::AdobeRgb),
            item(16, 16, horizontal_gradient(16, 16), Gamut::Bt2020),
            item(1, 1, solid_image(1, 1, 255, 0, 0, 255), Gamut::ProPhotoRgb),
        ]
    }

    /// Reference result: encodes each item on the current thread, in order.
    fn encode_serial(items: &[ImageInput]) -> Vec<ChromaHash> {
        let mut hashes = Vec::with_capacity(items.len());
        for item in items {
            hashes.push(ChromaHash::encode(item.w, item.h, &item.rgba, item.gamut));
        }
        hashes
    }

    #[test]
    fn batch_matches_serial() {
        let items = mixed_items();
        let expected = encode_serial(&items);
        assert_eq!(BatchEncoder::new().encode_batch(&items), expected);
    }

    #[test]
    fn batch_preserves_order() {
        // 64 distinct colours, so a misordered result should show up as a
        // hash mismatch against the serial reference.
        let items: Vec<ImageInput> = (0u8..64)
            .map(|i| ImageInput {
                w: 8,
                h: 8,
                rgba: solid_image(8, 8, i, 255 - i, i * 3, 255),
                gamut: Gamut::Srgb,
            })
            .collect();
        let expected = encode_serial(&items);
        assert_eq!(BatchEncoder::new().encode_batch(&items), expected);
    }

    #[test]
    fn empty_batch_returns_empty() {
        let hashes = BatchEncoder::new().encode_batch(&[]);
        assert!(hashes.is_empty());
    }

    #[test]
    fn single_thread_matches_default() {
        let items = mixed_items();
        let pooled = BatchEncoder::new().encode_batch(&items);
        let single = BatchEncoder::with_threads(1).encode_batch(&items);
        assert_eq!(single, pooled);
    }

    #[test]
    fn encoder_reusable_across_batches() {
        let items = mixed_items();
        let encoder = BatchEncoder::new();
        let runs = [encoder.encode_batch(&items), encoder.encode_batch(&items)];
        assert_eq!(runs[0], runs[1]);
        assert_eq!(runs[0], encode_serial(&items));
    }

    #[test]
    #[should_panic(expected = "item 1: rgba length mismatch")]
    fn invalid_item_panics_with_index() {
        let valid = ImageInput {
            w: 2,
            h: 2,
            rgba: solid_image(2, 2, 0, 0, 0, 255),
            gamut: Gamut::Srgb,
        };
        // 3 bytes cannot hold a 2×2 RGBA image (16 bytes).
        let truncated = ImageInput {
            w: 2,
            h: 2,
            rgba: Arc::from(vec![0u8; 3]),
            gamut: Gamut::Srgb,
        };
        let _ = BatchEncoder::new().encode_batch(&[valid, truncated]);
    }
}