Skip to main content

codec_eval/eval/
session.rs

1//! Evaluation session with callback-based codec interface.
2//!
3//! This module provides [`EvalSession`], the main entry point for codec evaluation.
4//! External crates provide encode/decode callbacks, and the session handles
5//! metrics calculation, caching, and report generation.
6
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::time::Instant;
10
11use imgref::ImgVec;
12use rgb::{RGB8, RGBA8};
13
14use crate::error::Result;
15use crate::eval::report::{CodecResult, CorpusReport, ImageReport};
16use crate::metrics::dssim::rgb8_to_dssim_image;
17use crate::metrics::{MetricConfig, MetricResult, calculate_psnr};
18use crate::viewing::ViewingCondition;
19
20/// Image data accepted by the evaluation session.
21///
22/// Supports both `imgref::ImgVec` types and raw slices for flexibility.
23/// For decoded images with ICC profiles, use the `WithIcc` variants.
24#[derive(Clone)]
25pub enum ImageData {
26    /// RGB8 image using imgref.
27    Rgb8(ImgVec<RGB8>),
28
29    /// RGBA8 image using imgref.
30    Rgba8(ImgVec<RGBA8>),
31
32    /// RGB8 raw slice with dimensions.
33    RgbSlice {
34        /// Pixel data in row-major order.
35        data: Vec<u8>,
36        /// Image width.
37        width: usize,
38        /// Image height.
39        height: usize,
40    },
41
42    /// RGBA8 raw slice with dimensions.
43    RgbaSlice {
44        /// Pixel data in row-major order.
45        data: Vec<u8>,
46        /// Image width.
47        width: usize,
48        /// Image height.
49        height: usize,
50    },
51
52    /// RGB8 raw slice with ICC profile for color management.
53    ///
54    /// Use this for decoded images that have embedded ICC profiles (e.g., XYB JPEGs).
55    /// The ICC profile will be used to transform pixels to sRGB before metric calculation.
56    RgbSliceWithIcc {
57        /// Pixel data in row-major order.
58        data: Vec<u8>,
59        /// Image width.
60        width: usize,
61        /// Image height.
62        height: usize,
63        /// ICC profile data (raw bytes from the image).
64        icc_profile: Vec<u8>,
65    },
66}
67
68impl ImageData {
69    /// Get image width.
70    #[must_use]
71    pub fn width(&self) -> usize {
72        match self {
73            Self::Rgb8(img) => img.width(),
74            Self::Rgba8(img) => img.width(),
75            Self::RgbSlice { width, .. }
76            | Self::RgbaSlice { width, .. }
77            | Self::RgbSliceWithIcc { width, .. } => *width,
78        }
79    }
80
81    /// Get image height.
82    #[must_use]
83    pub fn height(&self) -> usize {
84        match self {
85            Self::Rgb8(img) => img.height(),
86            Self::Rgba8(img) => img.height(),
87            Self::RgbSlice { height, .. }
88            | Self::RgbaSlice { height, .. }
89            | Self::RgbSliceWithIcc { height, .. } => *height,
90        }
91    }
92
93    /// Convert to RGB8 slice representation.
94    ///
95    /// Note: This does NOT apply ICC profile transformation. For ICC-aware
96    /// conversion, use [`to_rgb8_srgb()`] instead.
97    #[must_use]
98    pub fn to_rgb8_vec(&self) -> Vec<u8> {
99        match self {
100            Self::Rgb8(img) => img.pixels().flat_map(|p| [p.r, p.g, p.b]).collect(),
101            Self::Rgba8(img) => img.pixels().flat_map(|p| [p.r, p.g, p.b]).collect(),
102            Self::RgbSlice { data, .. } | Self::RgbSliceWithIcc { data, .. } => data.clone(),
103            Self::RgbaSlice {
104                data,
105                width,
106                height,
107            } => {
108                let mut rgb = Vec::with_capacity(width * height * 3);
109                for chunk in data.chunks_exact(4) {
110                    rgb.push(chunk[0]);
111                    rgb.push(chunk[1]);
112                    rgb.push(chunk[2]);
113                }
114                rgb
115            }
116        }
117    }
118
119    /// Get the ICC profile if present.
120    #[must_use]
121    pub fn icc_profile(&self) -> Option<&[u8]> {
122        match self {
123            Self::RgbSliceWithIcc { icc_profile, .. } => Some(icc_profile),
124            _ => None,
125        }
126    }
127
128    /// Get the color profile for this image.
129    #[must_use]
130    pub fn color_profile(&self) -> crate::metrics::ColorProfile {
131        match self {
132            Self::RgbSliceWithIcc { icc_profile, .. } => {
133                crate::metrics::ColorProfile::Icc(icc_profile.clone())
134            }
135            _ => crate::metrics::ColorProfile::Srgb,
136        }
137    }
138
139    /// Convert to sRGB RGB8 slice, applying ICC profile transformation if needed.
140    ///
141    /// This is the ICC-aware version of [`to_rgb8_vec()`]. Use this when you need
142    /// the pixels in sRGB color space for metric calculation.
143    pub fn to_rgb8_srgb(&self) -> crate::error::Result<Vec<u8>> {
144        let rgb = self.to_rgb8_vec();
145        let profile = self.color_profile();
146        crate::metrics::transform_to_srgb(&rgb, &profile)
147    }
148}
149
/// A single encode job handed to a codec's encode callback.
#[derive(Debug, Clone)]
pub struct EncodeRequest {
    /// Quality setting (0-100, codec-specific interpretation).
    pub quality: f64,

    /// Extra codec-specific options, keyed by option name.
    pub params: HashMap<String, String>,
}

impl EncodeRequest {
    /// Build a request for `quality` with no extra parameters.
    #[must_use]
    pub fn new(quality: f64) -> Self {
        let params = HashMap::new();
        Self { quality, params }
    }

    /// Return this request with `key` set to `value`.
    ///
    /// Setting the same key twice keeps the most recent value.
    #[must_use]
    pub fn with_param(self, key: &str, value: &str) -> Self {
        let Self {
            quality,
            mut params,
        } = self;
        params.insert(key.to_owned(), value.to_owned());
        Self { quality, params }
    }
}
177
/// Encode callback type.
///
/// Takes image data and encode request, returns encoded bytes.
/// The session calls it once per (codec, quality level) pair with the source
/// image and an [`EncodeRequest`] carrying that quality. It must be
/// `Send + Sync`.
pub type EncodeFn = Box<dyn Fn(&ImageData, &EncodeRequest) -> Result<Vec<u8>> + Send + Sync>;

/// Decode callback type.
///
/// Takes encoded bytes, returns decoded image data.
/// It receives the bytes produced by the codec's [`EncodeFn`]. Return
/// [`ImageData::RgbSliceWithIcc`] when the decoded image carries an ICC
/// profile, so the session can convert it to sRGB before computing metrics.
pub type DecodeFn = Box<dyn Fn(&[u8]) -> Result<ImageData> + Send + Sync>;
187
188/// Configuration for an evaluation session.
189#[derive(Debug, Clone)]
190pub struct EvalConfig {
191    /// Directory for report output (CSV, JSON).
192    pub report_dir: PathBuf,
193
194    /// Directory for caching encoded files.
195    pub cache_dir: Option<PathBuf>,
196
197    /// Viewing condition for perceptual metrics.
198    pub viewing: ViewingCondition,
199
200    /// Which metrics to calculate.
201    pub metrics: MetricConfig,
202
203    /// Quality levels to sweep.
204    pub quality_levels: Vec<f64>,
205}
206
207impl EvalConfig {
208    /// Create a new configuration builder.
209    #[must_use]
210    pub fn builder() -> EvalConfigBuilder {
211        EvalConfigBuilder::default()
212    }
213}
214
215/// Builder for [`EvalConfig`].
216#[derive(Debug, Default)]
217pub struct EvalConfigBuilder {
218    report_dir: Option<PathBuf>,
219    cache_dir: Option<PathBuf>,
220    viewing: Option<ViewingCondition>,
221    metrics: Option<MetricConfig>,
222    quality_levels: Option<Vec<f64>>,
223}
224
225impl EvalConfigBuilder {
226    /// Set the report output directory.
227    #[must_use]
228    pub fn report_dir(mut self, path: impl Into<PathBuf>) -> Self {
229        self.report_dir = Some(path.into());
230        self
231    }
232
233    /// Set the cache directory.
234    #[must_use]
235    pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
236        self.cache_dir = Some(path.into());
237        self
238    }
239
240    /// Set the viewing condition.
241    #[must_use]
242    pub fn viewing(mut self, viewing: ViewingCondition) -> Self {
243        self.viewing = Some(viewing);
244        self
245    }
246
247    /// Set which metrics to calculate.
248    #[must_use]
249    pub fn metrics(mut self, metrics: MetricConfig) -> Self {
250        self.metrics = Some(metrics);
251        self
252    }
253
254    /// Set quality levels to sweep.
255    #[must_use]
256    pub fn quality_levels(mut self, levels: Vec<f64>) -> Self {
257        self.quality_levels = Some(levels);
258        self
259    }
260
261    /// Build the configuration.
262    ///
263    /// # Panics
264    ///
265    /// Panics if `report_dir` is not set.
266    #[must_use]
267    pub fn build(self) -> EvalConfig {
268        EvalConfig {
269            report_dir: self.report_dir.expect("report_dir is required"),
270            cache_dir: self.cache_dir,
271            viewing: self.viewing.unwrap_or_default(),
272            metrics: self.metrics.unwrap_or_else(MetricConfig::all),
273            quality_levels: self
274                .quality_levels
275                .unwrap_or_else(|| vec![50.0, 60.0, 70.0, 80.0, 85.0, 90.0, 95.0]),
276        }
277    }
278}
279
/// Registered codec entry.
///
/// One of these is stored per `add_codec` / `add_codec_with_decode` call.
struct CodecEntry {
    /// Codec identifier; copied into each result's `codec_id`.
    id: String,
    /// Codec version string; copied into each result's `codec_version`.
    version: String,
    /// Encoder invoked once per configured quality level.
    encode: EncodeFn,
    /// Optional decoder. When `None`, results carry default (empty) metrics and no decode time.
    decode: Option<DecodeFn>,
}
287
/// Evaluation session for codec comparison.
///
/// A session holds an [`EvalConfig`] and a list of registered codecs.
/// [`EvalSession::evaluate_image`] encodes an image with every codec at every
/// configured quality level. Codecs registered with a decoder also have
/// quality metrics computed against the source image.
///
/// # Example
///
/// ```rust,ignore
/// use codec_eval::{EvalSession, EvalConfig, ViewingCondition, ImageData};
///
/// let config = EvalConfig::builder()
///     .report_dir("./reports")
///     .viewing(ViewingCondition::desktop())
///     .build();
///
/// let mut session = EvalSession::new(config);
///
/// session.add_codec("my-codec", "1.0.0", Box::new(|image, request| {
///     // Encode the image (`encoded_bytes` is a placeholder for your encoder's output)
///     Ok(encoded_bytes)
/// }));
///
/// // `image_data` is a placeholder `ImageData` value.
/// let report = session.evaluate_image("test.png", image_data)?;
/// ```
pub struct EvalSession {
    /// Session-wide settings (report directory, viewing condition, metrics, quality levels).
    config: EvalConfig,
    /// Registered codecs; `evaluate_image` visits them in registration order.
    codecs: Vec<CodecEntry>,
}
313
314impl EvalSession {
315    /// Create a new evaluation session.
316    #[must_use]
317    pub fn new(config: EvalConfig) -> Self {
318        Self {
319            config,
320            codecs: Vec::new(),
321        }
322    }
323
324    /// Register a codec with an encode callback.
325    pub fn add_codec(&mut self, id: &str, version: &str, encode: EncodeFn) -> &mut Self {
326        self.codecs.push(CodecEntry {
327            id: id.to_string(),
328            version: version.to_string(),
329            encode,
330            decode: None,
331        });
332        self
333    }
334
335    /// Register a codec with both encode and decode callbacks.
336    pub fn add_codec_with_decode(
337        &mut self,
338        id: &str,
339        version: &str,
340        encode: EncodeFn,
341        decode: DecodeFn,
342    ) -> &mut Self {
343        self.codecs.push(CodecEntry {
344            id: id.to_string(),
345            version: version.to_string(),
346            encode,
347            decode: Some(decode),
348        });
349        self
350    }
351
352    /// Get the number of registered codecs.
353    #[must_use]
354    pub fn codec_count(&self) -> usize {
355        self.codecs.len()
356    }
357
358    /// Evaluate a single image across all registered codecs.
359    ///
360    /// # Arguments
361    ///
362    /// * `name` - Image name or identifier.
363    /// * `image` - The image data to evaluate.
364    ///
365    /// # Returns
366    ///
367    /// An [`ImageReport`] containing results for all codec/quality combinations.
368    pub fn evaluate_image(&self, name: &str, image: ImageData) -> Result<ImageReport> {
369        let width = image.width() as u32;
370        let height = image.height() as u32;
371        let mut report = ImageReport::new(name.to_string(), width, height);
372
373        let reference_rgb = image.to_rgb8_vec();
374
375        for codec in &self.codecs {
376            for &quality in &self.config.quality_levels {
377                let request = EncodeRequest::new(quality);
378
379                // Encode
380                let start = Instant::now();
381                let encoded = (codec.encode)(&image, &request)?;
382                let encode_time = start.elapsed();
383
384                // Calculate metrics
385                let metrics = if let Some(ref decode) = codec.decode {
386                    // Decode and compare
387                    let start = Instant::now();
388                    let decoded = decode(&encoded)?;
389                    let decode_time = start.elapsed();
390
391                    // Convert decoded pixels to sRGB, applying ICC profile if present.
392                    // This ensures accurate metric calculation for XYB JPEGs and other
393                    // images with embedded ICC profiles.
394                    let decoded_rgb = decoded.to_rgb8_srgb()?;
395                    let metrics =
396                        self.calculate_metrics(&reference_rgb, &decoded_rgb, width, height)?;
397
398                    report.results.push(CodecResult {
399                        codec_id: codec.id.clone(),
400                        codec_version: codec.version.clone(),
401                        quality,
402                        file_size: encoded.len(),
403                        bits_per_pixel: (encoded.len() * 8) as f64 / (width as f64 * height as f64),
404                        encode_time,
405                        decode_time: Some(decode_time),
406                        metrics: metrics.clone(),
407                        perception: metrics.perception_level(),
408                        cached_path: None,
409                        codec_params: request.params,
410                    });
411                    continue;
412                } else {
413                    // No decoder, just record file size
414                    MetricResult::default()
415                };
416
417                report.results.push(CodecResult {
418                    codec_id: codec.id.clone(),
419                    codec_version: codec.version.clone(),
420                    quality,
421                    file_size: encoded.len(),
422                    bits_per_pixel: (encoded.len() * 8) as f64 / (width as f64 * height as f64),
423                    encode_time,
424                    decode_time: None,
425                    metrics,
426                    perception: None,
427                    cached_path: None,
428                    codec_params: request.params,
429                });
430            }
431        }
432
433        Ok(report)
434    }
435
436    /// Calculate metrics between reference and test images.
437    fn calculate_metrics(
438        &self,
439        reference: &[u8],
440        test: &[u8],
441        width: u32,
442        height: u32,
443    ) -> Result<MetricResult> {
444        let mut result = MetricResult::default();
445
446        // Apply XYB roundtrip to reference if enabled
447        let reference_for_metrics: std::borrow::Cow<'_, [u8]> = if self.config.metrics.xyb_roundtrip
448        {
449            std::borrow::Cow::Owned(crate::metrics::xyb_roundtrip(
450                reference,
451                width as usize,
452                height as usize,
453            ))
454        } else {
455            std::borrow::Cow::Borrowed(reference)
456        };
457
458        if self.config.metrics.psnr {
459            result.psnr = Some(calculate_psnr(
460                &reference_for_metrics,
461                test,
462                width as usize,
463                height as usize,
464            ));
465        }
466
467        if self.config.metrics.dssim {
468            let ref_img =
469                rgb8_to_dssim_image(&reference_for_metrics, width as usize, height as usize);
470            let test_img = rgb8_to_dssim_image(test, width as usize, height as usize);
471            result.dssim = Some(crate::metrics::dssim::calculate_dssim(
472                &ref_img,
473                &test_img,
474                &self.config.viewing,
475            )?);
476        }
477
478        if self.config.metrics.ssimulacra2 {
479            result.ssimulacra2 = Some(crate::metrics::ssimulacra2::calculate_ssimulacra2(
480                &reference_for_metrics,
481                test,
482                width as usize,
483                height as usize,
484            )?);
485        }
486
487        if self.config.metrics.butteraugli {
488            result.butteraugli = Some(crate::metrics::butteraugli::calculate_butteraugli(
489                &reference_for_metrics,
490                test,
491                width as usize,
492                height as usize,
493            )?);
494        }
495
496        Ok(result)
497    }
498
499    /// Write an image report to the configured report directory.
500    pub fn write_image_report(&self, report: &ImageReport) -> Result<()> {
501        std::fs::create_dir_all(&self.config.report_dir)?;
502
503        let json_path = self.config.report_dir.join(format!("{}.json", report.name));
504        let json = serde_json::to_string_pretty(report)?;
505        std::fs::write(json_path, json)?;
506
507        Ok(())
508    }
509
510    /// Write a corpus report to the configured report directory.
511    pub fn write_corpus_report(&self, report: &CorpusReport) -> Result<()> {
512        std::fs::create_dir_all(&self.config.report_dir)?;
513
514        let json_path = self.config.report_dir.join(format!("{}.json", report.name));
515        let json = serde_json::to_string_pretty(report)?;
516        std::fs::write(json_path, json)?;
517
518        // Also write CSV summary
519        let csv_path = self.config.report_dir.join(format!("{}.csv", report.name));
520        self.write_csv_summary(report, &csv_path)?;
521
522        Ok(())
523    }
524
525    /// Write a CSV summary of the corpus report.
526    fn write_csv_summary(&self, report: &CorpusReport, path: &Path) -> Result<()> {
527        let mut wtr = csv::Writer::from_path(path)?;
528
529        // Header
530        wtr.write_record([
531            "image",
532            "codec",
533            "version",
534            "quality",
535            "file_size",
536            "bpp",
537            "encode_ms",
538            "decode_ms",
539            "dssim",
540            "ssimulacra2",
541            "butteraugli",
542            "psnr",
543            "perception",
544        ])?;
545
546        for img in &report.images {
547            for result in &img.results {
548                wtr.write_record([
549                    &img.name,
550                    &result.codec_id,
551                    &result.codec_version,
552                    &result.quality.to_string(),
553                    &result.file_size.to_string(),
554                    &format!("{:.4}", result.bits_per_pixel),
555                    &result.encode_time.as_millis().to_string(),
556                    &result
557                        .decode_time
558                        .map_or(String::new(), |d| d.as_millis().to_string()),
559                    &result
560                        .metrics
561                        .dssim
562                        .map_or(String::new(), |d| format!("{:.6}", d)),
563                    &result
564                        .metrics
565                        .ssimulacra2
566                        .map_or(String::new(), |s| format!("{:.2}", s)),
567                    &result
568                        .metrics
569                        .butteraugli
570                        .map_or(String::new(), |b| format!("{:.4}", b)),
571                    &result
572                        .metrics
573                        .psnr
574                        .map_or(String::new(), |p| format!("{:.2}", p)),
575                    &result
576                        .perception
577                        .map_or(String::new(), |p| p.code().to_string()),
578                ])?;
579            }
580        }
581
582        wtr.flush()?;
583        Ok(())
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tightly packed RGB image whose bytes cycle through 0..=255.
    fn create_test_image(width: usize, height: usize) -> ImageData {
        let byte_count = width * height * 3;
        let data = (0..byte_count)
            .map(|i| (i % 256) as u8)
            .collect::<Vec<u8>>();
        ImageData::RgbSlice {
            data,
            width,
            height,
        }
    }

    #[test]
    fn test_image_data_dimensions() {
        let img = create_test_image(100, 50);
        assert_eq!((img.width(), img.height()), (100, 50));
    }

    #[test]
    fn test_encode_request() {
        let req = EncodeRequest::new(80.0).with_param("subsampling", "4:2:0");
        assert!((req.quality - 80.0).abs() < f64::EPSILON);
        assert_eq!(
            req.params.get("subsampling").map(String::as_str),
            Some("4:2:0")
        );
    }

    #[test]
    fn test_eval_config_builder() {
        let levels = vec![50.0, 75.0, 90.0];
        let config = EvalConfig::builder()
            .report_dir("/tmp/reports")
            .cache_dir("/tmp/cache")
            .viewing(ViewingCondition::laptop())
            .quality_levels(levels)
            .build();

        assert_eq!(config.report_dir, PathBuf::from("/tmp/reports"));
        assert_eq!(config.cache_dir.as_deref(), Some(Path::new("/tmp/cache")));
        assert!((config.viewing.acuity_ppd - 60.0).abs() < f64::EPSILON);
        assert_eq!(config.quality_levels.len(), 3);
    }

    #[test]
    fn test_session_add_codec() {
        let config = EvalConfig::builder().report_dir("/tmp/test").build();
        let encode: EncodeFn = Box::new(|_, _| Ok(vec![0u8; 100]));

        let mut session = EvalSession::new(config);
        session.add_codec("test", "1.0", encode);

        assert_eq!(session.codec_count(), 1);
    }
}