Skip to main content

ccalc_plot/
colormap.rs

1//! Colormap LUT data and imagesc rendering (ASCII and SVG/PNG).
2
3#[cfg(feature = "plot-svg")]
4use plotters::prelude::*;
5
6#[cfg(any(feature = "plot", feature = "plot-svg"))]
7use crate::FigureState;
8
9// ── Public API ─────────────────────────────────────────────────────────────
10
/// Names of every built-in colormap understood by [`apply_colormap`] and
/// accepted by [`validate_colormap`].
pub const VALID_COLORMAPS: &[&str] = &[
    "viridis", "inferno", "magma", "plasma", "hot", "cool", "jet", "gray",
];

/// Checks that `name` is one of the built-in colormaps.
///
/// # Errors
///
/// When `name` is not listed in [`VALID_COLORMAPS`], returns a message that
/// quotes the rejected name and lists every valid choice, comma-separated.
pub fn validate_colormap(name: &str) -> Result<(), String> {
    if !VALID_COLORMAPS.contains(&name) {
        let choices = VALID_COLORMAPS.join(", ");
        return Err(format!(
            "colormap: '{}' is not a recognised colormap. Valid colormaps: {}",
            name, choices
        ));
    }
    Ok(())
}
31
/// A colormap specification: either a built-in named colormap or a custom
/// N×3 look-up table supplied by the user.
///
/// Check a spec with [`validate_colormap_spec`] before rendering, and map a
/// normalised value through it with [`apply_colormap_spec`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ColormapSpec {
    /// One of the built-in named colormaps (e.g. `"viridis"`, `"hot"`).
    ///
    /// Valid names are listed in [`VALID_COLORMAPS`]; [`apply_colormap`]
    /// renders any other name as `"viridis"`.
    Named(String),
    /// Custom LUT: a vector of `(R, G, B)` triplets.
    ///
    /// [`validate_colormap_spec`] requires at least two entries. Component
    /// values are in `[0, 255]`; entries are treated as evenly spaced stops
    /// over `[0, 1]` and linearly interpolated.
    Custom(Vec<(u8, u8, u8)>),
}
45
46/// Maps a normalised value `t ∈ [0, 1]` to an `(R, G, B)` triple.
47///
48/// Values outside `[0, 1]` are clamped.  Unrecognised names fall back to
49/// `"viridis"`.
50///
51/// # Examples
52///
53/// ```
54/// use ccalc_plot::colormap::apply_colormap;
55/// let (r, g, b) = apply_colormap(0.0, "gray");
56/// assert_eq!((r, g, b), (0, 0, 0));
57/// let (r, g, b) = apply_colormap(1.0, "gray");
58/// assert_eq!((r, g, b), (255, 255, 255));
59/// ```
60pub fn apply_colormap(t: f64, name: &str) -> (u8, u8, u8) {
61    let t = t.clamp(0.0, 1.0);
62    match name {
63        "viridis" => lut_lerp(t, &VIRIDIS),
64        "inferno" => lut_lerp(t, &INFERNO),
65        "magma" => lut_lerp(t, &MAGMA),
66        "plasma" => lut_lerp(t, &PLASMA),
67        "hot" => lut_lerp(t, &HOT),
68        "cool" => lut_lerp(t, &COOL),
69        "jet" => lut_lerp(t, &JET),
70        "gray" => {
71            let v = (t * 255.0).round() as u8;
72            (v, v, v)
73        }
74        _ => lut_lerp(t, &VIRIDIS),
75    }
76}
77
78/// Maps a normalised value `t ∈ [0, 1]` to an `(R, G, B)` triple using `spec`.
79///
80/// Delegates to [`apply_colormap`] for [`ColormapSpec::Named`] and to the
81/// built-in LUT interpolator for [`ColormapSpec::Custom`].
82///
83/// # Examples
84///
85/// ```
86/// use ccalc_plot::colormap::{apply_colormap_spec, ColormapSpec};
87/// let spec = ColormapSpec::Named("gray".to_string());
88/// assert_eq!(apply_colormap_spec(0.0, &spec), (0, 0, 0));
89/// assert_eq!(apply_colormap_spec(1.0, &spec), (255, 255, 255));
90/// ```
91pub fn apply_colormap_spec(t: f64, spec: &ColormapSpec) -> (u8, u8, u8) {
92    match spec {
93        ColormapSpec::Named(name) => apply_colormap(t, name),
94        ColormapSpec::Custom(lut) => lut_lerp(t, lut),
95    }
96}
97
98/// Validates a [`ColormapSpec`], returning an error string on failure.
99///
100/// Named variants are checked against [`VALID_COLORMAPS`].  Custom variants
101/// require at least two LUT entries.
102pub fn validate_colormap_spec(spec: &ColormapSpec) -> Result<(), String> {
103    match spec {
104        ColormapSpec::Named(name) => validate_colormap(name),
105        ColormapSpec::Custom(lut) => {
106            if lut.len() < 2 {
107                Err("colormap: custom colormap must have at least 2 rows".into())
108            } else {
109                Ok(())
110            }
111        }
112    }
113}
114
115// ── LUT interpolation ──────────────────────────────────────────────────────
116
/// Linearly interpolates `lut` at position `t`, treating the entries as
/// evenly spaced stops from `t = 0` (first entry) to `t = 1` (last entry).
///
/// `t` is clamped into `[0, 1]`, so results never extrapolate beyond the end
/// entries. A single-entry LUT always yields that entry, and an empty LUT
/// yields black `(0, 0, 0)` instead of panicking. With two or more entries a
/// NaN `t` yields `(0, 0, 0)`: NaN survives the clamp, and Rust's saturating
/// float-to-integer casts turn it into 0.
fn lut_lerp(t: f64, lut: &[(u8, u8, u8)]) -> (u8, u8, u8) {
    let n = lut.len();
    if n == 0 {
        return (0, 0, 0);
    }
    if n == 1 {
        return lut[0];
    }
    let ts = t.clamp(0.0, 1.0) * (n - 1) as f64;
    // Truncation picks the lower neighbour; `min` keeps t == 1.0 on the last
    // segment so `hi` stays in bounds.
    let lo = (ts as usize).min(n - 2);
    let hi = lo + 1;
    let f = ts - lo as f64;
    let lerp = |a: u8, b: u8| (a as f64 + f * (b as f64 - a as f64)).round() as u8;
    let (a, b) = (lut[lo], lut[hi]);
    (lerp(a.0, b.0), lerp(a.1, b.1), lerp(a.2, b.2))
}
133
134// ── LUT data ───────────────────────────────────────────────────────────────
135
// Each table holds 8 stops that `lut_lerp` treats as evenly spaced over
// t ∈ [0, 1]: entry 0 is t = 0, entry 7 is t = 1.
// NOTE(review): the values look like approximate 8-point samplings of the
// matplotlib (viridis/inferno/magma/plasma) and MATLAB-style (hot/cool/jet)
// colormaps; they have not been checked against the reference tables.

/// Viridis stops: dark purple `(68, 1, 84)` up to yellow `(253, 231, 37)`.
const VIRIDIS: [(u8, u8, u8); 8] = [
    (68, 1, 84),
    (72, 40, 120),
    (62, 83, 160),
    (49, 104, 142),
    (53, 183, 121),
    (101, 203, 94),
    (180, 222, 44),
    (253, 231, 37),
];
/// Inferno stops: near-black `(0, 0, 4)` up to pale yellow `(252, 255, 164)`.
const INFERNO: [(u8, u8, u8); 8] = [
    (0, 0, 4),
    (40, 11, 84),
    (101, 21, 110),
    (159, 42, 99),
    (212, 72, 66),
    (245, 125, 21),
    (252, 190, 44),
    (252, 255, 164),
];
/// Magma stops: near-black `(0, 0, 4)` up to pale cream `(252, 253, 191)`.
const MAGMA: [(u8, u8, u8); 8] = [
    (0, 0, 4),
    (28, 16, 68),
    (79, 18, 123),
    (129, 37, 129),
    (181, 55, 122),
    (229, 89, 104),
    (251, 143, 107),
    (252, 253, 191),
];
/// Plasma stops: deep blue `(13, 8, 135)` up to yellow `(240, 249, 33)`.
const PLASMA: [(u8, u8, u8); 8] = [
    (13, 8, 135),
    (84, 2, 163),
    (139, 10, 165),
    (185, 50, 137),
    (219, 92, 104),
    (243, 135, 72),
    (253, 182, 44),
    (240, 249, 33),
];
/// Hot stops: black → red → yellow → white.
const HOT: [(u8, u8, u8); 8] = [
    (0, 0, 0),
    (96, 0, 0),
    (192, 0, 0),
    (255, 48, 0),
    (255, 144, 0),
    (255, 216, 0),
    (255, 255, 96),
    (255, 255, 255),
];
/// Cool stops: cyan `(0, 255, 255)` to magenta `(255, 0, 255)` in equal steps.
const COOL: [(u8, u8, u8); 8] = [
    (0, 255, 255),
    (36, 219, 255),
    (73, 182, 255),
    (109, 146, 255),
    (146, 109, 255),
    (182, 73, 255),
    (219, 36, 255),
    (255, 0, 255),
];
/// Jet stops: dark blue → blue → cyan → green → yellow → red → dark red.
const JET: [(u8, u8, u8); 8] = [
    (0, 0, 143),
    (0, 0, 255),
    (0, 218, 255),
    (0, 255, 36),
    (146, 255, 0),
    (255, 218, 0),
    (255, 36, 0),
    (143, 0, 0),
];
206
207// ── Data helpers ───────────────────────────────────────────────────────────
208
/// Returns the `(min, max)` of the finite values in `z`, widened so that the
/// range is never empty.
///
/// * NaN and ±∞ entries are ignored.
/// * If `z` holds no finite value (including an empty slice) the result is
///   `(0.0, 1.0)`.
/// * If all finite values coincide — within a relative tolerance of one
///   `f64::EPSILON` times the larger magnitude — `max` becomes `min + 1.0`.
///   The tolerance is relative, so data spanning a genuinely tiny interval
///   (e.g. `[0.0, 1e-17]`) keeps its real range instead of being flattened.
#[cfg(any(feature = "plot", feature = "plot-svg"))]
pub(crate) fn data_range(z: &[f64]) -> (f64, f64) {
    let (lo, hi) = z
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
            (lo.min(v), hi.max(v))
        });
    if lo > hi {
        // No finite samples at all: the fold seeds were never replaced.
        return (0.0, 1.0);
    }
    let scale = lo.abs().max(hi.abs());
    if hi - lo <= f64::EPSILON * scale {
        return (lo, lo + 1.0);
    }
    (lo, hi)
}
230
231// ── ASCII renderer ─────────────────────────────────────────────────────────
232
/// Renders `imagesc` as character art to stdout.
///
/// `z` is read row-major: cell `(r, c)` is `z[r * ncols + c]`. Each value is
/// normalised against the finite range of the displayed cells and drawn with
/// the 10-level density palette `" .:-=+*#@█"`, using ten equal-width
/// intensity bins. NaN cells are drawn blank; ±∞ clamp to the extreme glyphs.
/// If `z` holds fewer than `nrows * ncols` values, only the available cells
/// are drawn instead of panicking; values beyond `nrows * ncols` are ignored.
///
/// The title (if any) is printed first. A one-line colorbar showing the data
/// range is appended when `state.colorbar` is `true`, followed by the axis
/// labels (if any).
#[cfg(feature = "plot")]
pub fn render_imagesc_ascii(z: &[f64], nrows: usize, ncols: usize, state: &FigureState) {
    const DENSITY: [char; 10] = [' ', '.', ':', '-', '=', '+', '*', '#', '@', '█'];
    // floor(t * 10) splits [0, 1] into ten equal bins; `min` folds t == 1.0
    // into the top bin. A NaN t casts to 0 and lands on the blank glyph.
    let glyph = |t: f64| {
        let idx = (t * DENSITY.len() as f64) as usize;
        DENSITY[idx.min(DENSITY.len() - 1)]
    };

    if nrows == 0 || ncols == 0 {
        return;
    }

    // Only the cells that are actually drawn contribute to the range.
    let cells = &z[..z.len().min(nrows.saturating_mul(ncols))];
    let (z_min, z_max) = data_range(cells);
    let range = z_max - z_min;

    if let Some(t) = &state.title {
        println!("{t}");
    }

    for row in cells.chunks(ncols) {
        let line: String = row
            .iter()
            .map(|&v| {
                let t = if range > 0.0 {
                    ((v - z_min) / range).clamp(0.0, 1.0)
                } else {
                    0.5
                };
                glyph(t)
            })
            .collect();
        println!("{line}");
    }

    if state.colorbar {
        const STEPS: usize = 20;
        let gradient: String = (0..STEPS)
            .map(|i| glyph(i as f64 / (STEPS - 1) as f64))
            .collect();
        println!("{z_min:.4} [{gradient}] {z_max:.4}");
    }
    if let Some(xl) = &state.xlabel {
        println!("x: {xl}");
    }
    if let Some(yl) = &state.ylabel {
        println!("y: {yl}");
    }
}
285
286// ── SVG/PNG file renderer ──────────────────────────────────────────────────
287
/// Width in pixels reserved on the right-hand side of the canvas for the
/// colorbar strip when `state.colorbar` is set.
///
/// `draw_imagesc` splits the canvas at `width - CB_WIDTH` (saturating at 0),
/// and `draw_colorbar` sizes its margins and tick-label area to fit inside
/// this strip.
#[cfg(feature = "plot-svg")]
const CB_WIDTH: u32 = 80;
291
/// Writes a false-colour image of `z` to an SVG or PNG file.
///
/// The output format is chosen from the end of `path`, compared
/// ASCII-case-insensitively: `.svg` / `.SVG` → SVG, `.png` / `.PNG` → PNG.
/// The active colormap is taken from `state.colormap` (default `"viridis"`).
/// If `state.colorbar` is `true`, a gradient strip with value labels is
/// appended on the right side of the image.
/// Canvas size is taken from [`FigureState::canvas_size`] (default 800 × 600).
///
/// # Errors
///
/// Returns `"imagesc: unsupported format '<path>'"` for any other extension,
/// or the backend's error text when drawing or writing the file fails.
#[cfg(feature = "plot-svg")]
pub fn render_imagesc_file(
    z: &[f64],
    nrows: usize,
    ncols: usize,
    path: &str,
    state: FigureState,
) -> Result<(), String> {
    let (width, height) = state.canvas_size();
    // Fold only ASCII letters so e.g. "plot.PNG" is accepted; the backends
    // still receive the caller's original `path`.
    let lower = path.to_ascii_lowercase();
    if lower.ends_with(".svg") {
        let root = SVGBackend::new(path, (width, height)).into_drawing_area();
        draw_imagesc(z, nrows, ncols, &state, root, width)
    } else if lower.ends_with(".png") {
        let root = BitMapBackend::new(path, (width, height)).into_drawing_area();
        draw_imagesc(z, nrows, ncols, &state, root, width)
    } else {
        Err(format!("imagesc: unsupported format '{path}'"))
    }
}
317
/// Paints a complete `imagesc` figure onto `root` and presents it.
///
/// The canvas is filled with the figure background first. When `z` has at
/// least one row and column, the cells are drawn with the figure's colormap
/// (falling back to `"viridis"`). If `state.colorbar` is set, the canvas is
/// split so that the rightmost `CB_WIDTH` pixels of `width` hold a colorbar.
#[cfg(feature = "plot-svg")]
fn draw_imagesc<DB: DrawingBackend>(
    z: &[f64],
    nrows: usize,
    ncols: usize,
    state: &FigureState,
    root: DrawingArea<DB, plotters::coord::Shift>,
    width: u32,
) -> Result<(), String>
where
    DB::ErrorType: std::fmt::Display,
{
    let (bg_r, bg_g, bg_b) = state.effective_bg_rgb();
    root.fill(&RGBColor(bg_r, bg_g, bg_b))
        .map_err(|e| e.to_string())?;

    if nrows > 0 && ncols > 0 {
        let fallback = ColormapSpec::Named("viridis".to_string());
        let spec = state.colormap.as_ref().unwrap_or(&fallback);
        let (z_min, z_max) = data_range(z);
        let range = z_max - z_min;

        if state.colorbar {
            let split_at = (width.saturating_sub(CB_WIDTH)) as i32;
            let (image_area, bar_area) = root.split_horizontally(split_at);
            draw_imagesc_cells(&image_area, z, nrows, ncols, state, spec, z_min, range)?;
            draw_colorbar(&bar_area, z_min, z_max, spec)?;
        } else {
            draw_imagesc_cells(&root, z, nrows, ncols, state, spec, z_min, range)?;
        }
    }

    root.present().map_err(|e| e.to_string())
}
354
/// Draws the heat-map cells of `z` (row-major, `nrows × ncols`) into `area`,
/// together with the title, axis labels and tick labels from `state`.
///
/// Each value `v` is normalised as `(v - z_min) / range`, clamped to
/// `[0, 1]` (or `0.5` when `range` is not positive), and coloured with
/// `spec`. Row 0 of `z` is drawn at the top. All cells are submitted in one
/// `draw_series` call. If `z` holds fewer than `nrows * ncols` values, only
/// the available cells are drawn.
#[cfg(feature = "plot-svg")]
// Every argument is a distinct input computed by `draw_imagesc`, its only
// caller; a parameter struct would add indirection to a private helper.
#[allow(clippy::too_many_arguments)]
fn draw_imagesc_cells<DB: DrawingBackend>(
    area: &DrawingArea<DB, plotters::coord::Shift>,
    z: &[f64],
    nrows: usize,
    ncols: usize,
    state: &FigureState,
    spec: &ColormapSpec,
    z_min: f64,
    range: f64,
) -> Result<(), String>
where
    DB::ErrorType: std::fmt::Display,
{
    let title = state.title.as_deref().unwrap_or("");
    let xlabel = state.xlabel.as_deref().unwrap_or("");
    let ylabel = state.ylabel.as_deref().unwrap_or("");

    let mut chart = ChartBuilder::on(area)
        .caption(title, ("sans-serif", 20))
        .margin(30)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(0.0..(ncols as f64), 0.0..(nrows as f64))
        .map_err(|e| e.to_string())?;

    chart
        .configure_mesh()
        .x_desc(xlabel)
        .y_desc(ylabel)
        .disable_mesh()
        .draw()
        .map_err(|e| e.to_string())?;

    if nrows == 0 || ncols == 0 {
        // Nothing to fill; also keeps `chunks` below from receiving 0.
        return Ok(());
    }

    // Row 0 of Z is the top row; map it to y ∈ [nrows-1, nrows].
    let cells = z
        .chunks(ncols)
        .take(nrows)
        .enumerate()
        .flat_map(move |(r, row)| {
            let y_lo = (nrows - 1 - r) as f64;
            row.iter().enumerate().map(move |(c, &v)| {
                let t = if range > 0.0 {
                    ((v - z_min) / range).clamp(0.0, 1.0)
                } else {
                    0.5
                };
                let (rr, gg, bb) = apply_colormap_spec(t, spec);
                Rectangle::new(
                    [(c as f64, y_lo), ((c + 1) as f64, y_lo + 1.0)],
                    RGBColor(rr, gg, bb).filled(),
                )
            })
        });
    chart.draw_series(cells).map_err(|e| e.to_string())?;
    Ok(())
}
412
/// Draws a vertical colorbar for the value range `z_min..z_max` into `area`.
///
/// The gradient is made of 64 stacked rectangles. Rectangle `i` spans
/// `z_min + i·step` to `z_min + (i+1)·step` (capped at `z_max`), where
/// `step = (z_max - z_min) / 64`. It is coloured with `spec` at `t = i / 63`,
/// so the bottom and top rectangles show the colormap's end colours. Value
/// tick labels sit in the left-hand label area.
#[cfg(feature = "plot-svg")]
fn draw_colorbar<DB: DrawingBackend>(
    area: &DrawingArea<DB, plotters::coord::Shift>,
    z_min: f64,
    z_max: f64,
    spec: &ColormapSpec,
) -> Result<(), String>
where
    DB::ErrorType: std::fmt::Display,
{
    const STEPS: usize = 64;
    let band = (z_max - z_min) / STEPS as f64;
    let last_index = (STEPS - 1).max(1) as f64;

    // The strip is CB_WIDTH (80 px) wide: a 0 px left margin, a 4 px right
    // margin and a 40 px tick-label area leave 36 px for the gradient.
    let mut chart = ChartBuilder::on(area)
        .margin_top(30)
        .margin_bottom(30)
        .margin_left(0)
        .margin_right(4)
        .x_label_area_size(0)
        .y_label_area_size(40)
        .build_cartesian_2d(0.0..1.0, z_min..z_max)
        .map_err(|e| e.to_string())?;

    // Axis ticks and value labels go down first; the gradient is drawn on top.
    chart
        .configure_mesh()
        .disable_x_mesh()
        .disable_y_mesh()
        .draw()
        .map_err(|e| e.to_string())?;

    let bands = (0..STEPS).map(|i| {
        let bottom = z_min + i as f64 * band;
        let top = (bottom + band).min(z_max);
        let (r, g, b) = apply_colormap_spec(i as f64 / last_index, spec);
        Rectangle::new([(0.0, bottom), (1.0, top)], RGBColor(r, g, b).filled())
    });
    chart.draw_series(bands).map_err(|e| e.to_string())?;

    Ok(())
}
459
460// ── Tests ──────────────────────────────────────────────────────────────────
461
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_apply_colormap_gray_extremes() {
        assert_eq!(apply_colormap(0.0, "gray"), (0, 0, 0));
        assert_eq!(apply_colormap(1.0, "gray"), (255, 255, 255));
    }

    #[test]
    fn test_apply_colormap_lut_endpoints() {
        // t = 0 and t = 1 land exactly on the first and last table entries.
        assert_eq!(apply_colormap(0.0, "viridis"), VIRIDIS[0]);
        assert_eq!(apply_colormap(1.0, "viridis"), VIRIDIS[VIRIDIS.len() - 1]);
        assert_eq!(apply_colormap(0.0, "jet"), JET[0]);
        assert_eq!(apply_colormap(1.0, "jet"), JET[JET.len() - 1]);
    }

    #[test]
    fn test_colormap_custom_2pt() {
        let lut = vec![(0u8, 0, 0), (255u8, 255, 255)];
        let spec = ColormapSpec::Custom(lut);
        assert_eq!(apply_colormap_spec(0.0, &spec), (0, 0, 0));
        assert_eq!(apply_colormap_spec(1.0, &spec), (255, 255, 255));
    }

    #[test]
    fn test_colormap_custom_midpt() {
        let lut = vec![(0u8, 0, 0), (200u8, 100, 50)];
        let spec = ColormapSpec::Custom(lut);
        let (r, g, b) = apply_colormap_spec(0.5, &spec);
        assert_eq!(r, 100);
        assert_eq!(g, 50);
        assert_eq!(b, 25);
    }

    #[test]
    fn test_colormap_custom_too_short() {
        let spec = ColormapSpec::Custom(vec![(128u8, 0, 0)]);
        assert!(validate_colormap_spec(&spec).is_err());
    }

    #[test]
    fn test_colormap_custom_two_rows_valid() {
        let spec = ColormapSpec::Custom(vec![(0u8, 0, 0), (255u8, 255, 255)]);
        assert!(validate_colormap_spec(&spec).is_ok());
    }

    #[test]
    fn test_colormap_spec_named_viridis() {
        let spec = ColormapSpec::Named("viridis".to_string());
        assert!(validate_colormap_spec(&spec).is_ok());
        assert_eq!(
            apply_colormap_spec(0.0, &spec),
            apply_colormap(0.0, "viridis")
        );
        assert_eq!(
            apply_colormap_spec(1.0, &spec),
            apply_colormap(1.0, "viridis")
        );
    }

    #[test]
    fn test_apply_colormap_clamp() {
        // Values outside [0,1] are clamped, not panicked.
        let lo = apply_colormap(-1.0, "hot");
        let hi = apply_colormap(2.0, "hot");
        assert_eq!(lo, apply_colormap(0.0, "hot"));
        assert_eq!(hi, apply_colormap(1.0, "hot"));
    }

    #[test]
    fn test_apply_colormap_fallback() {
        // Unknown colormap names are rendered exactly like viridis.
        for t in [0.0, 0.25, 0.5, 1.0] {
            assert_eq!(
                apply_colormap(t, "unknown_colormap_xyz"),
                apply_colormap(t, "viridis")
            );
        }
    }

    #[test]
    fn test_validate_colormap_valid() {
        for name in VALID_COLORMAPS {
            assert!(validate_colormap(name).is_ok(), "'{name}' should be valid");
        }
    }

    #[test]
    fn test_validate_colormap_invalid() {
        let result = validate_colormap("rainbow");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("colormap"),
            "error should mention colormap: {msg}"
        );
        assert!(msg.contains("rainbow"), "error should quote the name: {msg}");
        for name in VALID_COLORMAPS {
            assert!(msg.contains(*name), "error should list '{name}': {msg}");
        }
    }

    #[cfg(any(feature = "plot", feature = "plot-svg"))]
    #[test]
    fn test_data_range_normal() {
        let (lo, hi) = data_range(&[3.0, 1.0, 4.0, 1.5]);
        assert!((lo - 1.0).abs() < 1e-9);
        assert!((hi - 4.0).abs() < 1e-9);
    }

    #[cfg(any(feature = "plot", feature = "plot-svg"))]
    #[test]
    fn test_data_range_all_nan() {
        let (lo, hi) = data_range(&[f64::NAN]);
        assert_eq!((lo, hi), (0.0, 1.0));
    }

    #[cfg(any(feature = "plot", feature = "plot-svg"))]
    #[test]
    fn test_data_range_constant() {
        // Constant input gets expanded so range > 0.
        let (lo, hi) = data_range(&[5.0, 5.0, 5.0]);
        assert!(hi > lo);
    }
}