//! `chromaframe_sdk/image.rs` — measures manually sampled facial regions in Lab space
//! and ranks candidate colors against the measured skin.

1use crate::color::{Lab, directional_delta, srgb_to_lab};
2use crate::quality::ImageError;
3use crate::score::{
4    ScoreError, harshness, quality_factor, region_factor, sample_factor, score_candidate,
5    try_confidence,
6};
7use crate::types::{
8    CandidateColor, CandidateRanking, CaptureQualityReport, ContrastMap, GoalVector,
9    MeasurementStatus, SdkError, SubjectColorProfile,
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::fmt;
14use thiserror::Error;
15
/// Errors returned while validating a [`MeasurementInput`] and measuring its regions.
#[derive(Debug, Error)]
pub enum MeasureError {
    /// Wraps an [`ImageError`] from the quality module (converted via `From`).
    #[error(transparent)]
    Image(#[from] ImageError),
    /// A region had no usable samples. In this file it is returned when a required region
    /// is given an empty sample list, and also when a sample converts to a non-finite Lab
    /// value (there is no dedicated variant for that case).
    #[error("region '{0}' has zero area or no accepted pixels")]
    EmptyRegion(&'static str),
    /// A region lies outside the image bounds.
    /// NOTE(review): not constructed in this file — presumably raised by image-based
    /// region extraction elsewhere; confirm.
    #[error("region '{0}' is out of bounds")]
    RegionOutOfBounds(&'static str),
    /// Skin samples are required for ranking.
    /// NOTE(review): not constructed in this file; `MeasurementEngine::measure` reports a
    /// missing skin region as `MeasurementStatus::InsufficientData` instead.
    #[error("skin measurement is required for ranking")]
    MissingSkin,
    /// Wraps an [`SdkError`] from input validation (e.g. empty or unnamed candidates).
    #[error(transparent)]
    Validation(#[from] SdkError),
    /// The named capture-quality field holds a NaN or infinite value.
    #[error("quality report contains non-finite value in {0}")]
    NonFiniteQuality(&'static str),
    /// Wraps a [`ScoreError`] from confidence or candidate scoring.
    #[error(transparent)]
    Score(#[from] ScoreError),
}
33
/// Outcome of measuring one facial region.
///
/// Serialized with an internal `status` tag whose values are the snake_case variant
/// names (`"measured"` / `"missing"`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum RegionMeasurement {
    /// Region colour in Lab (the per-channel mean when produced by `measure_region`) and
    /// the number of samples that contributed to it.
    Measured { lab: Lab, accepted_pixels: usize },
    /// Region could not be measured. `measure_region` uses `"not_provided"` and
    /// `"zero_accepted_pixels"` as reasons.
    Missing { reason: String },
}
40
/// Manually supplied sRGB pixel samples (`[r, g, b]` bytes) for each facial region.
///
/// A field is `None` when the caller did not sample that region. `Default` yields a
/// value with every region unset, so callers can provide only what they have:
/// `ManualRegionSamples { skin: Some(pixels), ..Default::default() }`.
///
/// `Debug` is implemented manually for this type so that pixel values are never printed.
#[derive(Clone, Default)]
pub struct ManualRegionSamples {
    /// Skin samples. The measurement engine treats skin as mandatory: an explicitly empty
    /// list is rejected, while `None` produces an insufficient-data report.
    pub skin: Option<Vec<[u8; 3]>>,
    /// Eyebrow samples.
    pub brow: Option<Vec<[u8; 3]>>,
    /// Iris samples.
    pub iris: Option<Vec<[u8; 3]>>,
    /// Sclera (eye white) samples.
    pub sclera: Option<Vec<[u8; 3]>>,
    /// Lip samples.
    pub lip: Option<Vec<[u8; 3]>>,
    /// Hair samples.
    pub hair: Option<Vec<[u8; 3]>>,
    /// Beard samples.
    pub beard: Option<Vec<[u8; 3]>>,
    /// Clothing samples.
    pub clothing: Option<Vec<[u8; 3]>>,
}
52
53impl fmt::Debug for ManualRegionSamples {
54    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
55        formatter
56            .debug_struct("ManualRegionSamples")
57            .field("skin", &redacted_sample_count(&self.skin))
58            .field("brow", &redacted_sample_count(&self.brow))
59            .field("iris", &redacted_sample_count(&self.iris))
60            .field("sclera", &redacted_sample_count(&self.sclera))
61            .field("lip", &redacted_sample_count(&self.lip))
62            .field("hair", &redacted_sample_count(&self.hair))
63            .field("beard", &redacted_sample_count(&self.beard))
64            .field("clothing", &redacted_sample_count(&self.clothing))
65            .finish()
66    }
67}
68
/// Everything [`MeasurementEngine::measure`] needs for one measurement run.
///
/// `Debug` is implemented manually for this type; the per-region pixel samples are
/// redacted by `ManualRegionSamples`' own `Debug` implementation.
#[derive(Clone)]
pub struct MeasurementInput {
    /// Capture-quality report; every numeric value in it must be finite.
    pub quality: CaptureQualityReport,
    /// Measurement mode, forwarded to the confidence computation.
    pub mode: crate::types::MeasurementMode,
    /// Goal vector; must parse successfully and is passed to candidate scoring.
    pub goal_vector: GoalVector,
    /// Candidate colors to rank; must be non-empty and every name must be non-blank.
    pub candidates: Vec<CandidateColor>,
    /// Manually sampled pixels for each facial region.
    pub samples: ManualRegionSamples,
}
77
78impl fmt::Debug for MeasurementInput {
79    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
80        formatter
81            .debug_struct("MeasurementInput")
82            .field("quality", &self.quality)
83            .field("mode", &self.mode)
84            .field("goal_vector", &self.goal_vector)
85            .field("candidates", &self.candidates)
86            .field("samples", &self.samples)
87            .finish()
88    }
89}
90
/// Result of [`MeasurementEngine::measure`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct MeasurementReport {
    /// `Complete`, or `InsufficientData` when no skin samples were provided.
    pub status: MeasurementStatus,
    /// Copy of the input capture-quality report.
    pub quality: CaptureQualityReport,
    /// Skin color profile; `None` when the status is `InsufficientData`.
    pub subject: Option<SubjectColorProfile>,
    /// Skin-relative contrasts for each feature region that was measured.
    pub contrast_map: ContrastMap,
    /// Candidates ordered by score (desc), confidence (desc), harshness (asc), then
    /// input order; empty when the status is `InsufficientData`.
    pub rankings: Vec<CandidateRanking>,
}
99
100pub struct MeasurementEngine;
101
102impl MeasurementEngine {
103    pub fn measure(input: &MeasurementInput) -> Result<MeasurementReport, MeasureError> {
104        validate_measurement_input(input)?;
105        let skin = measure_region("skin", input.samples.skin.as_deref(), true)?;
106        let brow = measure_region("brow", input.samples.brow.as_deref(), false)?;
107        let iris = measure_region("iris", input.samples.iris.as_deref(), false)?;
108        let sclera = measure_region("sclera", input.samples.sclera.as_deref(), false)?;
109        let lip = measure_region("lip", input.samples.lip.as_deref(), false)?;
110        let hair = measure_region("hair", input.samples.hair.as_deref(), false)?;
111        let beard = measure_region("beard", input.samples.beard.as_deref(), false)?;
112        let clothing = measure_region("clothing", input.samples.clothing.as_deref(), false)?;
113        let RegionMeasurement::Measured {
114            lab: skin_lab,
115            accepted_pixels: skin_count,
116        } = skin
117        else {
118            return Ok(MeasurementReport {
119                status: MeasurementStatus::InsufficientData,
120                quality: input.quality.clone(),
121                subject: None,
122                contrast_map: ContrastMap {
123                    contrasts: Vec::new(),
124                },
125                rankings: Vec::new(),
126            });
127        };
128        let mut contrasts = Vec::new();
129        for (name, region) in [
130            ("brow", &brow),
131            ("iris", &iris),
132            ("sclera", &sclera),
133            ("lip", &lip),
134            ("hair", &hair),
135            ("beard", &beard),
136            ("clothing", &clothing),
137        ] {
138            if let RegionMeasurement::Measured { lab, .. } = region {
139                contrasts.push(directional_delta("skin", skin_lab, name, *lab));
140            }
141        }
142        let baseline_feature_michelson = contrasts
143            .iter()
144            .map(|d| d.michelson_lightness_contrast)
145            .max_by(f32::total_cmp);
146        let available_weight =
147            0.40 + if !input.candidates.is_empty() {
148                0.20
149            } else {
150                0.0
151            } + if matches!(iris, RegionMeasurement::Measured { .. }) {
152                0.10
153            } else {
154                0.0
155            } + if matches!(lip, RegionMeasurement::Measured { .. }) {
156                0.10
157            } else {
158                0.0
159            } + if matches!(hair, RegionMeasurement::Measured { .. })
160                || matches!(brow, RegionMeasurement::Measured { .. })
161                || matches!(beard, RegionMeasurement::Measured { .. })
162            {
163                0.10
164            } else {
165                0.0
166            } + if matches!(clothing, RegionMeasurement::Measured { .. }) {
167                0.10
168            } else {
169                0.0
170            };
171        let confidence_value = try_confidence(
172            input.mode,
173            quality_factor(&input.quality),
174            region_factor(available_weight, 1.0),
175            sample_factor(&[
176                ("skin", skin_count),
177                ("brow", count(&brow)),
178                ("iris", count(&iris)),
179                ("lip", count(&lip)),
180                ("hair", count(&hair)),
181                ("beard", count(&beard)),
182                ("clothing", count(&clothing)),
183            ]),
184        )?;
185        let mut rankings = input
186            .candidates
187            .iter()
188            .enumerate()
189            .map(|(index, candidate)| {
190                let candidate_lab = srgb_to_lab(candidate.srgb);
191                let candidate_harshness =
192                    harshness(crate::color::delta_e00(candidate_lab, skin_lab))?;
193                let (score, components) = score_candidate(crate::score::CandidateScoreInput {
194                    skin_lab,
195                    candidate_lab,
196                    lip_lab: lab(&lip),
197                    iris_lab: lab(&iris),
198                    sclera_lab: lab(&sclera),
199                    hair_lab: lab(&hair),
200                    brow_lab: lab(&brow),
201                    baseline_feature_michelson,
202                    goal_vector: input.goal_vector.clone(),
203                    confidence: confidence_value,
204                })?;
205                Ok((
206                    index,
207                    CandidateRanking {
208                        name: candidate.name.clone(),
209                        score,
210                        confidence: confidence_value,
211                        harshness: candidate_harshness,
212                        label: "Goal-specific fit with measured uncertainty".to_string(),
213                        components,
214                    },
215                ))
216            })
217            .collect::<Result<Vec<_>, MeasureError>>()?;
218        sort_rankings_by_contract(&mut rankings);
219        let rankings = rankings.into_iter().map(|(_, ranking)| ranking).collect();
220        Ok(MeasurementReport {
221            status: MeasurementStatus::Complete,
222            quality: input.quality.clone(),
223            subject: Some(SubjectColorProfile {
224                skin_lab,
225                skin_ita: crate::color::ita_degrees(skin_lab),
226                skin_depth_proxy: crate::color::depth_proxy(skin_lab),
227            }),
228            contrast_map: ContrastMap { contrasts },
229            rankings,
230        })
231    }
232}
233
234fn validate_measurement_input(input: &MeasurementInput) -> Result<(), MeasureError> {
235    input.goal_vector.clone().parse()?;
236    if input.candidates.is_empty() {
237        return Err(SdkError::EmptyCandidates.into());
238    }
239    for candidate in &input.candidates {
240        if candidate.name.trim().is_empty() {
241            return Err(SdkError::InvalidCandidateColor {
242                name: candidate.name.clone(),
243            }
244            .into());
245        }
246    }
247    validate_quality_report(&input.quality)
248}
249
250fn sort_rankings_by_contract(rankings: &mut [(usize, CandidateRanking)]) {
251    rankings.sort_by(|left, right| {
252        right
253            .1
254            .score
255            .total_cmp(&left.1.score)
256            .then_with(|| right.1.confidence.total_cmp(&left.1.confidence))
257            .then_with(|| left.1.harshness.total_cmp(&right.1.harshness))
258            .then_with(|| left.0.cmp(&right.0))
259    });
260}
261
262fn validate_quality_report(report: &CaptureQualityReport) -> Result<(), MeasureError> {
263    finite_quality("over_clip_fraction", report.over_clip_fraction)?;
264    finite_quality("under_clip_fraction", report.under_clip_fraction)?;
265    finite_quality_check("white_balance", &report.white_balance)?;
266    finite_quality_check("blur", &report.blur)?;
267    finite_quality_check("shadow", &report.shadow)?;
268    finite_quality_check("face_angle", &report.face_angle)?;
269    finite_bool_quality_check("filters_or_makeup", &report.filters_or_makeup)?;
270    finite_bool_quality_check("occlusion", &report.occlusion)?;
271    finite_bool_quality_check("calibration_card", &report.calibration_card)?;
272    Ok(())
273}
274
275fn finite_quality(field: &'static str, value: f32) -> Result<(), MeasureError> {
276    if value.is_finite() {
277        return Ok(());
278    }
279    Err(MeasureError::NonFiniteQuality(field))
280}
281
282fn finite_quality_check(
283    field: &'static str,
284    check: &crate::types::QualityCheck<f32>,
285) -> Result<(), MeasureError> {
286    match check {
287        crate::types::QualityCheck::Measured { value, deduction } => {
288            finite_quality(field, *value)?;
289            finite_quality(field, *deduction)
290        }
291        crate::types::QualityCheck::NotMeasured { deduction, .. } => {
292            finite_quality(field, *deduction)
293        }
294    }
295}
296
297fn finite_bool_quality_check(
298    field: &'static str,
299    check: &crate::types::QualityCheck<bool>,
300) -> Result<(), MeasureError> {
301    match check {
302        crate::types::QualityCheck::Measured { deduction, .. }
303        | crate::types::QualityCheck::NotMeasured { deduction, .. } => {
304            finite_quality(field, *deduction)
305        }
306    }
307}
308
/// Describes a region's samples for `Debug` output without exposing pixel values.
///
/// Returns `[MISSING]` when the region was not supplied, otherwise
/// `[REDACTED; count=N]` where `N` is the number of samples.
fn redacted_sample_count(samples: &Option<Vec<[u8; 3]>>) -> String {
    match samples {
        Some(pixels) => format!("[REDACTED; count={}]", pixels.len()),
        None => String::from("[MISSING]"),
    }
}
315
316fn measure_region(
317    name: &'static str,
318    samples: Option<&[[u8; 3]]>,
319    required: bool,
320) -> Result<RegionMeasurement, MeasureError> {
321    let Some(samples) = samples else {
322        return Ok(RegionMeasurement::Missing {
323            reason: "not_provided".to_string(),
324        });
325    };
326    if samples.is_empty() && required {
327        return Err(MeasureError::EmptyRegion(name));
328    }
329    if samples.is_empty() {
330        return Ok(RegionMeasurement::Missing {
331            reason: "zero_accepted_pixels".to_string(),
332        });
333    }
334    let labs: Vec<_> = samples.iter().map(|rgb| srgb_to_lab(*rgb)).collect();
335    if labs
336        .iter()
337        .any(|lab| !lab.l.is_finite() || !lab.a.is_finite() || !lab.b.is_finite())
338    {
339        return Err(MeasureError::EmptyRegion(name));
340    }
341    let count = labs.len();
342    let (l, a, b) = labs.iter().fold((0.0, 0.0, 0.0), |acc, lab| {
343        (acc.0 + lab.l, acc.1 + lab.a, acc.2 + lab.b)
344    });
345    Ok(RegionMeasurement::Measured {
346        lab: Lab {
347            l: l / count as f32,
348            a: a / count as f32,
349            b: b / count as f32,
350        },
351        accepted_pixels: count,
352    })
353}
354fn count(region: &RegionMeasurement) -> usize {
355    if let RegionMeasurement::Measured {
356        accepted_pixels, ..
357    } = region
358    {
359        *accepted_pixels
360    } else {
361        0
362    }
363}
364fn lab(region: &RegionMeasurement) -> Option<Lab> {
365    if let RegionMeasurement::Measured { lab, .. } = region {
366        Some(*lab)
367    } else {
368        None
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::score::ScoreComponents;

    /// Builds a ranking with fixed placeholder score components. Only the fields the sort
    /// contract inspects (score, confidence, harshness) vary between tests.
    fn ranking(name: &str, score: f32, confidence: f32, harshness: f32) -> CandidateRanking {
        CandidateRanking {
            name: name.to_string(),
            score,
            confidence,
            harshness,
            label: "test".to_string(),
            components: ScoreComponents {
                skin_quality: 0.5,
                feature_readability: 0.5,
                eye_support: 0.5,
                lip_skin_harmony: 0.5,
                hair_brow_coherence: 0.5,
                goal_alignment: 0.5,
                total_penalty: 0.0,
            },
        }
    }

    /// Sorts `rankings` by contract and returns the candidate names in the resulting order.
    fn sorted_names(mut rankings: Vec<(usize, CandidateRanking)>) -> Vec<String> {
        sort_rankings_by_contract(&mut rankings);
        rankings.into_iter().map(|(_, ranking)| ranking.name).collect()
    }

    #[test]
    fn ranking_sort_uses_lower_harshness_then_stable_order() {
        let mut rankings = vec![
            (0, ranking("higher_harshness", 42.0, 0.8, 0.10)),
            (1, ranking("lower_harshness", 42.0, 0.8, 0.01)),
            (2, ranking("stable_a", 40.0, 0.8, 0.01)),
            (3, ranking("stable_b", 40.0, 0.8, 0.01)),
        ];
        sort_rankings_by_contract(&mut rankings);
        let names: Vec<_> = rankings
            .iter()
            .map(|(_, ranking)| ranking.name.as_str())
            .collect();
        assert_eq!(
            names,
            vec![
                "lower_harshness",
                "higher_harshness",
                "stable_a",
                "stable_b"
            ]
        );
    }

    #[test]
    fn ranking_sort_puts_higher_score_first() {
        let names = sorted_names(vec![
            (0, ranking("lower_score", 10.0, 0.9, 0.0)),
            (1, ranking("higher_score", 90.0, 0.1, 0.9)),
        ]);
        assert_eq!(names, ["higher_score", "lower_score"]);
    }

    #[test]
    fn ranking_sort_prefers_higher_confidence_before_lower_harshness() {
        let names = sorted_names(vec![
            (0, ranking("low_confidence", 42.0, 0.5, 0.01)),
            (1, ranking("high_confidence", 42.0, 0.9, 0.20)),
        ]);
        assert_eq!(names, ["high_confidence", "low_confidence"]);
    }

    #[test]
    fn manual_samples_debug_redacts_pixel_values() {
        let samples = ManualRegionSamples {
            skin: Some(vec![[201, 143, 177], [202, 144, 178]]),
            brow: None,
            iris: Some(Vec::new()),
            sclera: None,
            lip: None,
            hair: None,
            beard: None,
            clothing: None,
        };
        let rendered = format!("{samples:?}");
        assert!(rendered.contains("[REDACTED; count=2]"));
        assert!(rendered.contains("[REDACTED; count=0]"));
        assert!(rendered.contains("[MISSING]"));
        assert!(!rendered.contains("201"));
        assert!(!rendered.contains("143"));
    }

    #[test]
    fn measure_region_reports_missing_and_empty_regions() {
        assert_eq!(
            measure_region("brow", None, false).unwrap(),
            RegionMeasurement::Missing {
                reason: "not_provided".to_string()
            }
        );
        let empty: &[[u8; 3]] = &[];
        assert_eq!(
            measure_region("lip", Some(empty), false).unwrap(),
            RegionMeasurement::Missing {
                reason: "zero_accepted_pixels".to_string()
            }
        );
        assert!(matches!(
            measure_region("skin", Some(empty), true),
            Err(MeasureError::EmptyRegion("skin"))
        ));
    }

    #[test]
    fn measure_region_averages_identical_samples_exactly() {
        let pixel: [u8; 3] = [120, 80, 60];
        let samples: &[[u8; 3]] = &[pixel, pixel];
        assert_eq!(
            measure_region("skin", Some(samples), true).unwrap(),
            RegionMeasurement::Measured {
                lab: srgb_to_lab(pixel),
                accepted_pixels: 2,
            }
        );
    }
}