codec_eval/eval/
session.rs

1//! Evaluation session with callback-based codec interface.
2//!
3//! This module provides [`EvalSession`], the main entry point for codec evaluation.
4//! External crates provide encode/decode callbacks, and the session handles
5//! metrics calculation, caching, and report generation.
6
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::time::Instant;
10
11use imgref::ImgVec;
12use rgb::{RGB8, RGBA8};
13
14use crate::error::Result;
15use crate::eval::report::{CodecResult, CorpusReport, ImageReport};
16use crate::metrics::dssim::rgb8_to_dssim_image;
17use crate::metrics::{MetricConfig, MetricResult, calculate_psnr};
18use crate::viewing::ViewingCondition;
19
20/// Image data accepted by the evaluation session.
21///
22/// Supports both `imgref::ImgVec` types and raw slices for flexibility.
23#[derive(Clone)]
24pub enum ImageData {
25    /// RGB8 image using imgref.
26    Rgb8(ImgVec<RGB8>),
27
28    /// RGBA8 image using imgref.
29    Rgba8(ImgVec<RGBA8>),
30
31    /// RGB8 raw slice with dimensions.
32    RgbSlice {
33        /// Pixel data in row-major order.
34        data: Vec<u8>,
35        /// Image width.
36        width: usize,
37        /// Image height.
38        height: usize,
39    },
40
41    /// RGBA8 raw slice with dimensions.
42    RgbaSlice {
43        /// Pixel data in row-major order.
44        data: Vec<u8>,
45        /// Image width.
46        width: usize,
47        /// Image height.
48        height: usize,
49    },
50}
51
52impl ImageData {
53    /// Get image width.
54    #[must_use]
55    pub fn width(&self) -> usize {
56        match self {
57            Self::Rgb8(img) => img.width(),
58            Self::Rgba8(img) => img.width(),
59            Self::RgbSlice { width, .. } => *width,
60            Self::RgbaSlice { width, .. } => *width,
61        }
62    }
63
64    /// Get image height.
65    #[must_use]
66    pub fn height(&self) -> usize {
67        match self {
68            Self::Rgb8(img) => img.height(),
69            Self::Rgba8(img) => img.height(),
70            Self::RgbSlice { height, .. } => *height,
71            Self::RgbaSlice { height, .. } => *height,
72        }
73    }
74
75    /// Convert to RGB8 slice representation.
76    #[must_use]
77    pub fn to_rgb8_vec(&self) -> Vec<u8> {
78        match self {
79            Self::Rgb8(img) => img.pixels().flat_map(|p| [p.r, p.g, p.b]).collect(),
80            Self::Rgba8(img) => img.pixels().flat_map(|p| [p.r, p.g, p.b]).collect(),
81            Self::RgbSlice { data, .. } => data.clone(),
82            Self::RgbaSlice {
83                data,
84                width,
85                height,
86            } => {
87                let mut rgb = Vec::with_capacity(width * height * 3);
88                for chunk in data.chunks_exact(4) {
89                    rgb.push(chunk[0]);
90                    rgb.push(chunk[1]);
91                    rgb.push(chunk[2]);
92                }
93                rgb
94            }
95        }
96    }
97}
98
/// Parameters for a single invocation of a codec's encode callback.
#[derive(Debug, Clone)]
pub struct EncodeRequest {
    /// Quality setting, nominally 0-100; each codec decides how it maps onto
    /// its own encoder settings.
    pub quality: f64,

    /// Extra codec-specific options as string key/value pairs.
    pub params: HashMap<String, String>,
}

impl EncodeRequest {
    /// Build a request at `quality` with no extra parameters.
    #[must_use]
    pub fn new(quality: f64) -> Self {
        let params = HashMap::new();
        Self { quality, params }
    }

    /// Return this request with `key` set to `value`.
    ///
    /// Setting the same key again replaces the earlier value.
    #[must_use]
    pub fn with_param(self, key: &str, value: &str) -> Self {
        let mut request = self;
        request
            .params
            .insert(String::from(key), String::from(value));
        request
    }
}
126
/// Encode callback type.
///
/// Receives the source image and the [`EncodeRequest`] for one quality level
/// and returns the encoded bytes. The session records the byte count as the
/// file size. Errors returned here abort the evaluation via `?`.
///
/// The `Send + Sync` bounds allow the boxed callback to be shared across
/// threads.
pub type EncodeFn = Box<dyn Fn(&ImageData, &EncodeRequest) -> Result<Vec<u8>> + Send + Sync>;

/// Decode callback type.
///
/// Receives bytes previously produced by the matching [`EncodeFn`] and returns
/// the decoded image. The session compares it against the source to compute
/// quality metrics.
pub type DecodeFn = Box<dyn Fn(&[u8]) -> Result<ImageData> + Send + Sync>;
136
137/// Configuration for an evaluation session.
138#[derive(Debug, Clone)]
139pub struct EvalConfig {
140    /// Directory for report output (CSV, JSON).
141    pub report_dir: PathBuf,
142
143    /// Directory for caching encoded files.
144    pub cache_dir: Option<PathBuf>,
145
146    /// Viewing condition for perceptual metrics.
147    pub viewing: ViewingCondition,
148
149    /// Which metrics to calculate.
150    pub metrics: MetricConfig,
151
152    /// Quality levels to sweep.
153    pub quality_levels: Vec<f64>,
154}
155
156impl EvalConfig {
157    /// Create a new configuration builder.
158    #[must_use]
159    pub fn builder() -> EvalConfigBuilder {
160        EvalConfigBuilder::default()
161    }
162}
163
164/// Builder for [`EvalConfig`].
165#[derive(Debug, Default)]
166pub struct EvalConfigBuilder {
167    report_dir: Option<PathBuf>,
168    cache_dir: Option<PathBuf>,
169    viewing: Option<ViewingCondition>,
170    metrics: Option<MetricConfig>,
171    quality_levels: Option<Vec<f64>>,
172}
173
174impl EvalConfigBuilder {
175    /// Set the report output directory.
176    #[must_use]
177    pub fn report_dir(mut self, path: impl Into<PathBuf>) -> Self {
178        self.report_dir = Some(path.into());
179        self
180    }
181
182    /// Set the cache directory.
183    #[must_use]
184    pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
185        self.cache_dir = Some(path.into());
186        self
187    }
188
189    /// Set the viewing condition.
190    #[must_use]
191    pub fn viewing(mut self, viewing: ViewingCondition) -> Self {
192        self.viewing = Some(viewing);
193        self
194    }
195
196    /// Set which metrics to calculate.
197    #[must_use]
198    pub fn metrics(mut self, metrics: MetricConfig) -> Self {
199        self.metrics = Some(metrics);
200        self
201    }
202
203    /// Set quality levels to sweep.
204    #[must_use]
205    pub fn quality_levels(mut self, levels: Vec<f64>) -> Self {
206        self.quality_levels = Some(levels);
207        self
208    }
209
210    /// Build the configuration.
211    ///
212    /// # Panics
213    ///
214    /// Panics if `report_dir` is not set.
215    #[must_use]
216    pub fn build(self) -> EvalConfig {
217        EvalConfig {
218            report_dir: self.report_dir.expect("report_dir is required"),
219            cache_dir: self.cache_dir,
220            viewing: self.viewing.unwrap_or_default(),
221            metrics: self.metrics.unwrap_or_else(MetricConfig::all),
222            quality_levels: self
223                .quality_levels
224                .unwrap_or_else(|| vec![50.0, 60.0, 70.0, 80.0, 85.0, 90.0, 95.0]),
225        }
226    }
227}
228
229/// Registered codec entry.
230struct CodecEntry {
231    id: String,
232    version: String,
233    encode: EncodeFn,
234    decode: Option<DecodeFn>,
235}
236
237/// Evaluation session for codec comparison.
238///
239/// # Example
240///
241/// ```rust,ignore
242/// use codec_eval::{EvalSession, EvalConfig, ViewingCondition, ImageData};
243///
244/// let config = EvalConfig::builder()
245///     .report_dir("./reports")
246///     .viewing(ViewingCondition::desktop())
247///     .build();
248///
249/// let mut session = EvalSession::new(config);
250///
251/// session.add_codec("my-codec", "1.0.0", Box::new(|image, request| {
252///     // Encode the image
253///     Ok(encoded_bytes)
254/// }));
255///
256/// let report = session.evaluate_image("test.png", image_data)?;
257/// ```
258pub struct EvalSession {
259    config: EvalConfig,
260    codecs: Vec<CodecEntry>,
261}
262
263impl EvalSession {
264    /// Create a new evaluation session.
265    #[must_use]
266    pub fn new(config: EvalConfig) -> Self {
267        Self {
268            config,
269            codecs: Vec::new(),
270        }
271    }
272
273    /// Register a codec with an encode callback.
274    pub fn add_codec(&mut self, id: &str, version: &str, encode: EncodeFn) -> &mut Self {
275        self.codecs.push(CodecEntry {
276            id: id.to_string(),
277            version: version.to_string(),
278            encode,
279            decode: None,
280        });
281        self
282    }
283
284    /// Register a codec with both encode and decode callbacks.
285    pub fn add_codec_with_decode(
286        &mut self,
287        id: &str,
288        version: &str,
289        encode: EncodeFn,
290        decode: DecodeFn,
291    ) -> &mut Self {
292        self.codecs.push(CodecEntry {
293            id: id.to_string(),
294            version: version.to_string(),
295            encode,
296            decode: Some(decode),
297        });
298        self
299    }
300
301    /// Get the number of registered codecs.
302    #[must_use]
303    pub fn codec_count(&self) -> usize {
304        self.codecs.len()
305    }
306
307    /// Evaluate a single image across all registered codecs.
308    ///
309    /// # Arguments
310    ///
311    /// * `name` - Image name or identifier.
312    /// * `image` - The image data to evaluate.
313    ///
314    /// # Returns
315    ///
316    /// An [`ImageReport`] containing results for all codec/quality combinations.
317    pub fn evaluate_image(&self, name: &str, image: ImageData) -> Result<ImageReport> {
318        let width = image.width() as u32;
319        let height = image.height() as u32;
320        let mut report = ImageReport::new(name.to_string(), width, height);
321
322        let reference_rgb = image.to_rgb8_vec();
323
324        for codec in &self.codecs {
325            for &quality in &self.config.quality_levels {
326                let request = EncodeRequest::new(quality);
327
328                // Encode
329                let start = Instant::now();
330                let encoded = (codec.encode)(&image, &request)?;
331                let encode_time = start.elapsed();
332
333                // Calculate metrics
334                let metrics = if let Some(ref decode) = codec.decode {
335                    // Decode and compare
336                    let start = Instant::now();
337                    let decoded = decode(&encoded)?;
338                    let decode_time = start.elapsed();
339
340                    let decoded_rgb = decoded.to_rgb8_vec();
341                    let metrics =
342                        self.calculate_metrics(&reference_rgb, &decoded_rgb, width, height)?;
343
344                    report.results.push(CodecResult {
345                        codec_id: codec.id.clone(),
346                        codec_version: codec.version.clone(),
347                        quality,
348                        file_size: encoded.len(),
349                        bits_per_pixel: (encoded.len() * 8) as f64 / (width as f64 * height as f64),
350                        encode_time,
351                        decode_time: Some(decode_time),
352                        metrics: metrics.clone(),
353                        perception: metrics.perception_level(),
354                        cached_path: None,
355                        codec_params: request.params,
356                    });
357                    continue;
358                } else {
359                    // No decoder, just record file size
360                    MetricResult::default()
361                };
362
363                report.results.push(CodecResult {
364                    codec_id: codec.id.clone(),
365                    codec_version: codec.version.clone(),
366                    quality,
367                    file_size: encoded.len(),
368                    bits_per_pixel: (encoded.len() * 8) as f64 / (width as f64 * height as f64),
369                    encode_time,
370                    decode_time: None,
371                    metrics,
372                    perception: None,
373                    cached_path: None,
374                    codec_params: request.params,
375                });
376            }
377        }
378
379        Ok(report)
380    }
381
382    /// Calculate metrics between reference and test images.
383    fn calculate_metrics(
384        &self,
385        reference: &[u8],
386        test: &[u8],
387        width: u32,
388        height: u32,
389    ) -> Result<MetricResult> {
390        let mut result = MetricResult::default();
391
392        if self.config.metrics.psnr {
393            result.psnr = Some(calculate_psnr(
394                reference,
395                test,
396                width as usize,
397                height as usize,
398            ));
399        }
400
401        if self.config.metrics.dssim {
402            let ref_img = rgb8_to_dssim_image(reference, width as usize, height as usize);
403            let test_img = rgb8_to_dssim_image(test, width as usize, height as usize);
404            result.dssim = Some(crate::metrics::dssim::calculate_dssim(
405                &ref_img,
406                &test_img,
407                &self.config.viewing,
408            )?);
409        }
410
411        if self.config.metrics.ssimulacra2 {
412            result.ssimulacra2 = Some(crate::metrics::ssimulacra2::calculate_ssimulacra2(
413                reference,
414                test,
415                width as usize,
416                height as usize,
417            )?);
418        }
419
420        if self.config.metrics.butteraugli {
421            result.butteraugli = Some(crate::metrics::butteraugli::calculate_butteraugli(
422                reference,
423                test,
424                width as usize,
425                height as usize,
426            )?);
427        }
428
429        Ok(result)
430    }
431
432    /// Write an image report to the configured report directory.
433    pub fn write_image_report(&self, report: &ImageReport) -> Result<()> {
434        std::fs::create_dir_all(&self.config.report_dir)?;
435
436        let json_path = self.config.report_dir.join(format!("{}.json", report.name));
437        let json = serde_json::to_string_pretty(report)?;
438        std::fs::write(json_path, json)?;
439
440        Ok(())
441    }
442
443    /// Write a corpus report to the configured report directory.
444    pub fn write_corpus_report(&self, report: &CorpusReport) -> Result<()> {
445        std::fs::create_dir_all(&self.config.report_dir)?;
446
447        let json_path = self.config.report_dir.join(format!("{}.json", report.name));
448        let json = serde_json::to_string_pretty(report)?;
449        std::fs::write(json_path, json)?;
450
451        // Also write CSV summary
452        let csv_path = self.config.report_dir.join(format!("{}.csv", report.name));
453        self.write_csv_summary(report, &csv_path)?;
454
455        Ok(())
456    }
457
458    /// Write a CSV summary of the corpus report.
459    fn write_csv_summary(&self, report: &CorpusReport, path: &Path) -> Result<()> {
460        let mut wtr = csv::Writer::from_path(path)?;
461
462        // Header
463        wtr.write_record([
464            "image",
465            "codec",
466            "version",
467            "quality",
468            "file_size",
469            "bpp",
470            "encode_ms",
471            "decode_ms",
472            "dssim",
473            "ssimulacra2",
474            "butteraugli",
475            "psnr",
476            "perception",
477        ])?;
478
479        for img in &report.images {
480            for result in &img.results {
481                wtr.write_record([
482                    &img.name,
483                    &result.codec_id,
484                    &result.codec_version,
485                    &result.quality.to_string(),
486                    &result.file_size.to_string(),
487                    &format!("{:.4}", result.bits_per_pixel),
488                    &result.encode_time.as_millis().to_string(),
489                    &result
490                        .decode_time
491                        .map_or(String::new(), |d| d.as_millis().to_string()),
492                    &result
493                        .metrics
494                        .dssim
495                        .map_or(String::new(), |d| format!("{:.6}", d)),
496                    &result
497                        .metrics
498                        .ssimulacra2
499                        .map_or(String::new(), |s| format!("{:.2}", s)),
500                    &result
501                        .metrics
502                        .butteraugli
503                        .map_or(String::new(), |b| format!("{:.4}", b)),
504                    &result
505                        .metrics
506                        .psnr
507                        .map_or(String::new(), |p| format!("{:.2}", p)),
508                    &result
509                        .perception
510                        .map_or(String::new(), |p| p.code().to_string()),
511                ])?;
512            }
513        }
514
515        wtr.flush()?;
516        Ok(())
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `width` x `height` RGB test image with a repeating byte ramp.
    fn create_test_image(width: usize, height: usize) -> ImageData {
        let data: Vec<u8> = (0..width * height * 3).map(|i| (i % 256) as u8).collect();
        ImageData::RgbSlice {
            data,
            width,
            height,
        }
    }

    #[test]
    fn test_image_data_dimensions() {
        let img = create_test_image(100, 50);
        assert_eq!(img.width(), 100);
        assert_eq!(img.height(), 50);
    }

    #[test]
    fn test_rgba_slice_to_rgb_drops_alpha() {
        let img = ImageData::RgbaSlice {
            data: vec![1, 2, 3, 4, 5, 6, 7, 8],
            width: 2,
            height: 1,
        };
        assert_eq!(img.to_rgb8_vec(), vec![1, 2, 3, 5, 6, 7]);
    }

    #[test]
    fn test_encode_request() {
        let req = EncodeRequest::new(80.0).with_param("subsampling", "4:2:0");
        assert!((req.quality - 80.0).abs() < f64::EPSILON);
        assert_eq!(req.params.get("subsampling"), Some(&"4:2:0".to_string()));
    }

    #[test]
    fn test_eval_config_builder() {
        let config = EvalConfig::builder()
            .report_dir("/tmp/reports")
            .cache_dir("/tmp/cache")
            .viewing(ViewingCondition::laptop())
            .quality_levels(vec![50.0, 75.0, 90.0])
            .build();

        assert_eq!(config.report_dir, PathBuf::from("/tmp/reports"));
        assert_eq!(config.cache_dir, Some(PathBuf::from("/tmp/cache")));
        assert!((config.viewing.acuity_ppd - 60.0).abs() < f64::EPSILON);
        assert_eq!(config.quality_levels.len(), 3);
    }

    #[test]
    fn test_session_add_codec() {
        let config = EvalConfig::builder().report_dir("/tmp/test").build();

        let mut session = EvalSession::new(config);
        session.add_codec("test", "1.0", Box::new(|_, _| Ok(vec![0u8; 100])));

        assert_eq!(session.codec_count(), 1);
    }

    #[test]
    fn test_evaluate_image_without_decoder() {
        let config = EvalConfig::builder()
            .report_dir("/tmp/test")
            .quality_levels(vec![50.0, 90.0])
            .build();

        let mut session = EvalSession::new(config);
        session.add_codec("test", "1.0", Box::new(|_, _| Ok(vec![0u8; 100])));

        let report = match session.evaluate_image("img", create_test_image(10, 10)) {
            Ok(report) => report,
            Err(_) => panic!("evaluate_image failed for an encode-only codec"),
        };

        // One result per quality level, in sweep order.
        assert_eq!(report.results.len(), 2);
        assert!((report.results[0].quality - 50.0).abs() < f64::EPSILON);
        assert!((report.results[1].quality - 90.0).abs() < f64::EPSILON);
        for result in &report.results {
            assert_eq!(result.codec_id, "test");
            assert_eq!(result.file_size, 100);
            // 100 bytes * 8 bits over 10 * 10 pixels.
            assert!((result.bits_per_pixel - 8.0).abs() < f64::EPSILON);
            assert!(result.decode_time.is_none());
            assert!(result.perception.is_none());
        }
    }
}