1#[cfg(feature = "plot-svg")]
4use plotters::prelude::*;
5
6#[cfg(any(feature = "plot", feature = "plot-svg"))]
7use crate::FigureState;
8
/// Names of the built-in colormaps accepted by [`validate_colormap`].
///
/// Each name here has a dedicated arm in [`apply_colormap`].
pub const VALID_COLORMAPS: &[&str] = &[
    "viridis",
    "inferno",
    "magma",
    "plasma",
    "hot",
    "cool",
    "jet",
    "gray",
];
15
16pub fn validate_colormap(name: &str) -> Result<(), String> {
21 if VALID_COLORMAPS.contains(&name) {
22 Ok(())
23 } else {
24 Err(format!(
25 "colormap: '{}' is not a recognised colormap. Valid colormaps: {}",
26 name,
27 VALID_COLORMAPS.join(", ")
28 ))
29 }
30}
31
/// How a colormap is chosen: either a built-in map by name or an explicit lookup table.
#[derive(Debug, PartialEq, Clone)]
pub enum ColormapSpec {
    /// A built-in map, looked up by name (see `VALID_COLORMAPS`).
    Named(String),
    /// Explicit RGB stops, interpolated as evenly spaced points over `[0, 1]`.
    Custom(Vec<(u8, u8, u8)>),
}
45
46pub fn apply_colormap(t: f64, name: &str) -> (u8, u8, u8) {
61 let t = t.clamp(0.0, 1.0);
62 match name {
63 "viridis" => lut_lerp(t, &VIRIDIS),
64 "inferno" => lut_lerp(t, &INFERNO),
65 "magma" => lut_lerp(t, &MAGMA),
66 "plasma" => lut_lerp(t, &PLASMA),
67 "hot" => lut_lerp(t, &HOT),
68 "cool" => lut_lerp(t, &COOL),
69 "jet" => lut_lerp(t, &JET),
70 "gray" => {
71 let v = (t * 255.0).round() as u8;
72 (v, v, v)
73 }
74 _ => lut_lerp(t, &VIRIDIS),
75 }
76}
77
78pub fn apply_colormap_spec(t: f64, spec: &ColormapSpec) -> (u8, u8, u8) {
92 match spec {
93 ColormapSpec::Named(name) => apply_colormap(t, name),
94 ColormapSpec::Custom(lut) => lut_lerp(t, lut),
95 }
96}
97
98pub fn validate_colormap_spec(spec: &ColormapSpec) -> Result<(), String> {
103 match spec {
104 ColormapSpec::Named(name) => validate_colormap(name),
105 ColormapSpec::Custom(lut) => {
106 if lut.len() < 2 {
107 Err("colormap: custom colormap must have at least 2 rows".into())
108 } else {
109 Ok(())
110 }
111 }
112 }
113}
114
/// Linearly interpolates a colour from `lut`, treating its entries as evenly spaced
/// stops over `[0, 1]` (entry `i` sits at `i / (len - 1)`).
///
/// The function is total:
/// - `t` is clamped to `[0, 1]` and NaN is treated as `0.0`. Out-of-range input therefore
///   yields the nearest end colour instead of an extrapolated, saturated one.
/// - A single-entry table always returns that entry.
/// - An empty table returns black `(0, 0, 0)` instead of panicking.
fn lut_lerp(t: f64, lut: &[(u8, u8, u8)]) -> (u8, u8, u8) {
    match lut {
        [] => return (0, 0, 0),
        [only] => return *only,
        _ => {}
    }
    let n = lut.len();
    let t = if t.is_nan() { 0.0 } else { t.clamp(0.0, 1.0) };
    let ts = t * (n - 1) as f64;
    // Cap at n - 2 so `hi` stays in bounds when t == 1.0 (then f == 1.0 selects lut[n - 1]).
    let lo = (ts as usize).min(n - 2);
    let hi = lo + 1;
    let f = ts - lo as f64;
    let lerp = |a: u8, b: u8| (f64::from(a) + f * (f64::from(b) - f64::from(a))).round() as u8;
    (
        lerp(lut[lo].0, lut[hi].0),
        lerp(lut[lo].1, lut[hi].1),
        lerp(lut[lo].2, lut[hi].2),
    )
}
133
/// Eight evenly spaced samples (t = 0, 1/7, …, 1) of matplotlib's `viridis` colormap.
///
/// `lut_lerp` treats the entries as equally spaced stops. They must therefore be uniform
/// samples of the reference map; otherwise the perceptually uniform gradient is distorted.
const VIRIDIS: [(u8, u8, u8); 8] = [
    (68, 1, 84),    // #440154
    (70, 51, 126),  // #46337E
    (54, 92, 141),  // #365C8D
    (39, 127, 142), // #277F8E
    (31, 161, 135), // #1FA187
    (74, 193, 109), // #4AC16D
    (159, 218, 58), // #9FDA3A
    (253, 231, 37), // #FDE725
];
/// Eight RGB stops for the `"inferno"` colormap, from near-black to pale yellow.
///
/// `lut_lerp` interpolates between them as evenly spaced stops over `[0, 1]`.
// NOTE(review): presumably uniform samples of matplotlib's inferno — TODO confirm each
// stop against the reference table, since lut_lerp assumes even spacing.
const INFERNO: [(u8, u8, u8); 8] = [
    (0, 0, 4),
    (40, 11, 84),
    (101, 21, 110),
    (159, 42, 99),
    (212, 72, 66),
    (245, 125, 21),
    (252, 190, 44),
    (252, 255, 164),
];
/// Eight RGB stops for the `"magma"` colormap, from near-black to pale cream.
///
/// `lut_lerp` interpolates between them as evenly spaced stops over `[0, 1]`.
// NOTE(review): presumably uniform samples of matplotlib's magma — TODO confirm each
// stop against the reference table, since lut_lerp assumes even spacing.
const MAGMA: [(u8, u8, u8); 8] = [
    (0, 0, 4),
    (28, 16, 68),
    (79, 18, 123),
    (129, 37, 129),
    (181, 55, 122),
    (229, 89, 104),
    (251, 143, 107),
    (252, 253, 191),
];
/// Eight RGB stops for the `"plasma"` colormap, from deep blue-violet to yellow.
///
/// `lut_lerp` interpolates between them as evenly spaced stops over `[0, 1]`.
// NOTE(review): presumably uniform samples of matplotlib's plasma — TODO confirm each
// stop against the reference table, since lut_lerp assumes even spacing.
const PLASMA: [(u8, u8, u8); 8] = [
    (13, 8, 135),
    (84, 2, 163),
    (139, 10, 165),
    (185, 50, 137),
    (219, 92, 104),
    (243, 135, 72),
    (253, 182, 44),
    (240, 249, 33),
];
/// Eight RGB stops for the `"hot"` colormap.
///
/// The ramp goes from black through red (red channel first), then yellow (green channel
/// ramps), and finally white (blue channel ramps last).
/// `lut_lerp` interpolates between them as evenly spaced stops over `[0, 1]`.
const HOT: [(u8, u8, u8); 8] = [
    (0, 0, 0),
    (96, 0, 0),
    (192, 0, 0),
    (255, 48, 0),
    (255, 144, 0),
    (255, 216, 0),
    (255, 255, 96),
    (255, 255, 255),
];
/// Eight RGB stops for the `"cool"` colormap, a straight line from cyan to magenta.
///
/// Entry `i` is approximately `(r, 255 - r, 255)` with `r = round(255 * i / 7)`.
/// `lut_lerp` interpolates between them as evenly spaced stops over `[0, 1]`.
const COOL: [(u8, u8, u8); 8] = [
    (0, 255, 255),
    (36, 219, 255),
    (73, 182, 255),
    (109, 146, 255),
    (146, 109, 255),
    (182, 73, 255),
    (219, 36, 255),
    (255, 0, 255),
];
/// Eight RGB stops for the `"jet"` colormap.
///
/// The ramp runs dark blue → blue → cyan → green → yellow-orange → red → dark red.
/// `lut_lerp` interpolates between them as evenly spaced stops over `[0, 1]`.
const JET: [(u8, u8, u8); 8] = [
    (0, 0, 143),
    (0, 0, 255),
    (0, 218, 255),
    (0, 255, 36),
    (146, 255, 0),
    (255, 218, 0),
    (255, 36, 0),
    (143, 0, 0),
];
206
#[cfg(any(feature = "plot", feature = "plot-svg"))]
/// Returns `(min, max)` over the finite values of `z`, used to normalise colours.
///
/// - Non-finite entries (NaN, ±∞) are ignored.
/// - With no finite values the range defaults to `(0.0, 1.0)`.
/// - If every finite value is identical, the range is widened to `(v, v + 1.0)` so that
///   callers do not divide by zero.
///
/// The degenerate check is exact rather than against `f64::EPSILON`. An absolute epsilon
/// would flatten legitimate data whose spread is merely small in magnitude (e.g. values
/// around `1e-20`) into a single colour.
pub(crate) fn data_range(z: &[f64]) -> (f64, f64) {
    let (lo, hi) = z
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
            (lo.min(v), hi.max(v))
        });
    if lo > hi {
        // No finite samples at all: the fold left its (+inf, -inf) seed untouched.
        return (0.0, 1.0);
    }
    if hi <= lo {
        // All finite samples are equal.
        return (lo, lo + 1.0);
    }
    (lo, hi)
}
230
#[cfg(feature = "plot")]
/// Prints the row-major `nrows × ncols` matrix `z` to stdout as an ASCII heat map.
///
/// Each value is normalised against the finite range of `z` and drawn with one of ten
/// glyphs, from light (`' '`) to dense (`'█'`). Output order:
/// 1. the title, if set on `state`;
/// 2. the image rows;
/// 3. a text colour bar, if enabled;
/// 4. the x and y axis labels, if set.
///
/// If `z` holds fewer than `nrows * ncols` values, only the available values are printed
/// (the last row may be short) instead of panicking on an out-of-range index.
pub fn render_imagesc_ascii(z: &[f64], nrows: usize, ncols: usize, state: &FigureState) {
    const DENSITY: [char; 10] = [' ', '.', ':', '-', '=', '+', '*', '#', '@', '█'];

    // Normalised [0, 1] value -> glyph. A NaN lands on index 0, because float-to-int
    // `as` casts map NaN to 0.
    let glyph = |t: f64| DENSITY[((t * 9.0) as usize).min(9)];

    if nrows == 0 || ncols == 0 {
        return;
    }

    let (z_min, z_max) = data_range(z);
    let range = z_max - z_min;

    if let Some(t) = &state.title {
        println!("{t}");
    }

    // Build each row in one reusable buffer and emit it with a single `println!`.
    // Printing cell by cell re-locks stdout for every character.
    let mut line = String::with_capacity(ncols * '█'.len_utf8());
    for row in z.chunks(ncols).take(nrows) {
        line.clear();
        line.extend(row.iter().map(|&v| {
            let t = if range > 0.0 {
                ((v - z_min) / range).clamp(0.0, 1.0)
            } else {
                0.5
            };
            glyph(t)
        }));
        println!("{line}");
    }

    if state.colorbar {
        let steps = 20_usize;
        let gradient: String = (0..steps)
            .map(|i| glyph(i as f64 / (steps - 1).max(1) as f64))
            .collect();
        println!("{z_min:.4} [{gradient}] {z_max:.4}");
    }
    if let Some(xl) = &state.xlabel {
        println!("x: {xl}");
    }
    if let Some(yl) = &state.ylabel {
        println!("y: {yl}");
    }
}
285
#[cfg(feature = "plot-svg")]
/// Horizontal space, in backend pixel units, carved off the right-hand side of the canvas
/// for the colour bar when `state.colorbar` is set.
const CB_WIDTH: u32 = 80;
291
#[cfg(feature = "plot-svg")]
/// Renders the row-major `nrows × ncols` matrix `z` as a colour image and writes it to `path`.
///
/// The output format is chosen from the file extension, matched case-insensitively, so
/// `plot.SVG` and `plot.PNG` are accepted too:
/// - `.svg` uses the plotters SVG backend;
/// - `.png` uses the plotters bitmap backend.
///
/// # Errors
/// Returns an error for any other extension, or if drawing or writing the file fails.
pub fn render_imagesc_file(
    z: &[f64],
    nrows: usize,
    ncols: usize,
    path: &str,
    state: FigureState,
) -> Result<(), String> {
    let (width, height) = state.canvas_size();
    // Only used for the extension check; the backend still receives the original path.
    let lower = path.to_ascii_lowercase();
    if lower.ends_with(".svg") {
        let root = SVGBackend::new(path, (width, height)).into_drawing_area();
        draw_imagesc(z, nrows, ncols, &state, root, width)
    } else if lower.ends_with(".png") {
        let root = BitMapBackend::new(path, (width, height)).into_drawing_area();
        draw_imagesc(z, nrows, ncols, &state, root, width)
    } else {
        Err(format!("imagesc: unsupported format '{path}'"))
    }
}
317
#[cfg(feature = "plot-svg")]
/// Paints the background, then draws the image (and optionally the colour bar) onto `root`.
///
/// With `state.colorbar` set, the rightmost `CB_WIDTH` units of `width` are reserved for
/// the bar. The colormap defaults to `"viridis"` when `state.colormap` is unset. An empty
/// image (zero rows or columns) produces a background-only figure.
fn draw_imagesc<DB: DrawingBackend>(
    z: &[f64],
    nrows: usize,
    ncols: usize,
    state: &FigureState,
    root: DrawingArea<DB, plotters::coord::Shift>,
    width: u32,
) -> Result<(), String>
where
    DB::ErrorType: std::fmt::Display,
{
    let (bg_r, bg_g, bg_b) = state.effective_bg_rgb();
    root.fill(&RGBColor(bg_r, bg_g, bg_b))
        .map_err(|e| e.to_string())?;

    let has_cells = nrows != 0 && ncols != 0;
    if has_cells {
        let fallback = ColormapSpec::Named("viridis".to_string());
        let spec = state.colormap.as_ref().unwrap_or(&fallback);
        let (z_min, z_max) = data_range(z);
        let span = z_max - z_min;

        if state.colorbar {
            let image_width = (width.saturating_sub(CB_WIDTH)) as i32;
            let (image_area, bar_area) = root.split_horizontally(image_width);
            draw_imagesc_cells(&image_area, z, nrows, ncols, state, spec, z_min, span)?;
            draw_colorbar(&bar_area, z_min, z_max, spec)?;
        } else {
            draw_imagesc_cells(&root, z, nrows, ncols, state, spec, z_min, span)?;
        }
    }

    root.present().map_err(|e| e.to_string())
}
354
#[cfg(feature = "plot-svg")]
/// Draws the image cells of the row-major `nrows × ncols` matrix `z` into `area`.
///
/// Steps:
/// - Builds a chart with the title and axis labels from `state`.
/// - Draws one filled unit square per value, coloured by `spec` after normalising with
///   `z_min` and `range`.
/// - Places row 0 of `z` at the top of the image.
///
/// # Errors
/// Returns an error when `z` holds fewer than `nrows * ncols` values (or the size
/// overflows), or when plotters fails to draw.
// Every argument is needed to lay out and colour one image; bundling them into a struct
// would only serve this single private call site.
#[allow(clippy::too_many_arguments)]
fn draw_imagesc_cells<DB: DrawingBackend>(
    area: &DrawingArea<DB, plotters::coord::Shift>,
    z: &[f64],
    nrows: usize,
    ncols: usize,
    state: &FigureState,
    spec: &ColormapSpec,
    z_min: f64,
    range: f64,
) -> Result<(), String>
where
    DB::ErrorType: std::fmt::Display,
{
    // Validate up front so malformed input is reported as an error instead of
    // panicking on an out-of-range index halfway through drawing.
    let cell_count = nrows
        .checked_mul(ncols)
        .ok_or_else(|| format!("imagesc: {nrows}x{ncols} image is too large"))?;
    if z.len() < cell_count {
        return Err(format!(
            "imagesc: expected at least {cell_count} values for a {nrows}x{ncols} image, got {}",
            z.len()
        ));
    }

    let title = state.title.as_deref().unwrap_or("");
    let xlabel = state.xlabel.as_deref().unwrap_or("");
    let ylabel = state.ylabel.as_deref().unwrap_or("");

    let mut chart = ChartBuilder::on(area)
        .caption(title, ("sans-serif", 20))
        .margin(30)
        .x_label_area_size(40)
        .y_label_area_size(50)
        .build_cartesian_2d(0.0..(ncols as f64), 0.0..(nrows as f64))
        .map_err(|e| e.to_string())?;

    chart
        .configure_mesh()
        .x_desc(xlabel)
        .y_desc(ylabel)
        .disable_mesh()
        .draw()
        .map_err(|e| e.to_string())?;

    if cell_count == 0 {
        // Also keeps `chunks_exact` below from being called with a zero chunk size.
        return Ok(());
    }

    // One lazily built series of rectangles instead of a `draw_series` call per cell.
    // Row 0 of `z` maps to the top band [nrows - 1, nrows] of the y axis.
    let cells = z[..cell_count]
        .chunks_exact(ncols)
        .enumerate()
        .flat_map(|(r, row)| {
            let y_lo = (nrows - 1 - r) as f64;
            let y_hi = y_lo + 1.0;
            row.iter().enumerate().map(move |(c, &v)| {
                let t = if range > 0.0 {
                    ((v - z_min) / range).clamp(0.0, 1.0)
                } else {
                    0.5
                };
                let (rr, gg, bb) = apply_colormap_spec(t, spec);
                Rectangle::new(
                    [(c as f64, y_lo), ((c + 1) as f64, y_hi)],
                    RGBColor(rr, gg, bb).filled(),
                )
            })
        });
    chart.draw_series(cells).map_err(|e| e.to_string())?;
    Ok(())
}
412
#[cfg(feature = "plot-svg")]
/// Draws a vertical colour bar spanning `z_min..z_max` into `area`.
///
/// The value range is split into 64 equal horizontal bands stacked bottom to top.
/// Band `i` is filled with the colour `spec` yields at `i / 63`, so the bottom band shows
/// the low end of the map and the top band the high end. Tick labels are drawn on the
/// y axis.
fn draw_colorbar<DB: DrawingBackend>(
    area: &DrawingArea<DB, plotters::coord::Shift>,
    z_min: f64,
    z_max: f64,
    spec: &ColormapSpec,
) -> Result<(), String>
where
    DB::ErrorType: std::fmt::Display,
{
    const BANDS: usize = 64;
    let band_height = (z_max - z_min) / BANDS as f64;
    let colour_denominator = (BANDS - 1).max(1) as f64;

    let mut chart = ChartBuilder::on(area)
        .margin_top(30)
        .margin_bottom(30)
        .margin_left(0)
        .margin_right(4)
        .x_label_area_size(0)
        .y_label_area_size(40)
        .build_cartesian_2d(0.0..1.0, z_min..z_max)
        .map_err(|e| e.to_string())?;

    // Axis labels only; grid lines would be drawn over the gradient.
    chart
        .configure_mesh()
        .disable_x_mesh()
        .disable_y_mesh()
        .draw()
        .map_err(|e| e.to_string())?;

    let bands = (0..BANDS).map(|i| {
        let bottom = z_min + i as f64 * band_height;
        // Clamp the top so floating-point accumulation cannot overshoot the axis.
        let top = (bottom + band_height).min(z_max);
        let (r, g, b) = apply_colormap_spec(i as f64 / colour_denominator, spec);
        Rectangle::new([(0.0, bottom), (1.0, top)], RGBColor(r, g, b).filled())
    });
    chart.draw_series(bands).map_err(|e| e.to_string())?;

    Ok(())
}
459
#[cfg(test)]
// Unit tests for colormap lookup, custom tables, validation and (feature-gated) range helpers.
mod tests {
    use super::*;

    // Gray maps t = 0 to black and t = 1 to white.
    #[test]
    fn test_apply_colormap_gray_extremes() {
        assert_eq!(apply_colormap(0.0, "gray"), (0, 0, 0));
        assert_eq!(apply_colormap(1.0, "gray"), (255, 255, 255));
    }

    // A two-stop custom table returns its end stops exactly at t = 0 and t = 1.
    #[test]
    fn test_colormap_custom_2pt() {
        let lut = vec![(0u8, 0, 0), (255u8, 255, 255)];
        let spec = ColormapSpec::Custom(lut);
        assert_eq!(apply_colormap_spec(0.0, &spec), (0, 0, 0));
        assert_eq!(apply_colormap_spec(1.0, &spec), (255, 255, 255));
    }

    // Halfway between two stops, each channel is the rounded linear midpoint.
    #[test]
    fn test_colormap_custom_midpt() {
        let lut = vec![(0u8, 0, 0), (200u8, 100, 50)];
        let spec = ColormapSpec::Custom(lut);
        let (r, g, b) = apply_colormap_spec(0.5, &spec);
        assert_eq!(r, 100);
        assert_eq!(g, 50);
        assert_eq!(b, 25);
    }

    // Validation rejects custom tables with fewer than two rows.
    #[test]
    fn test_colormap_custom_too_short() {
        let spec = ColormapSpec::Custom(vec![(128u8, 0, 0)]);
        assert!(validate_colormap_spec(&spec).is_err());
    }

    // A named spec validates and maps identically to the plain name-based lookup.
    #[test]
    fn test_colormap_spec_named_viridis() {
        let spec = ColormapSpec::Named("viridis".to_string());
        assert!(validate_colormap_spec(&spec).is_ok());
        assert_eq!(
            apply_colormap_spec(0.0, &spec),
            apply_colormap(0.0, "viridis")
        );
        assert_eq!(
            apply_colormap_spec(1.0, &spec),
            apply_colormap(1.0, "viridis")
        );
    }

    // Out-of-range t is clamped to the end colours.
    #[test]
    fn test_apply_colormap_clamp() {
        let lo = apply_colormap(-1.0, "hot");
        let hi = apply_colormap(2.0, "hot");
        assert_eq!(lo, apply_colormap(0.0, "hot"));
        assert_eq!(hi, apply_colormap(1.0, "hot"));
    }

    // Unknown names must not panic; the returned colour itself is not asserted.
    #[test]
    fn test_apply_colormap_fallback() {
        let _ = apply_colormap(0.5, "unknown_colormap_xyz");
    }

    // Every advertised name passes validation.
    #[test]
    fn test_validate_colormap_valid() {
        for name in VALID_COLORMAPS {
            assert!(validate_colormap(name).is_ok(), "'{name}' should be valid");
        }
    }

    // An unknown name is rejected with a message that mentions "colormap".
    #[test]
    fn test_validate_colormap_invalid() {
        let result = validate_colormap("rainbow");
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("colormap"),
            "error should mention colormap: {msg}"
        );
    }

    // Min/max of ordinary finite data.
    #[cfg(any(feature = "plot", feature = "plot-svg"))]
    #[test]
    fn test_data_range_normal() {
        let (lo, hi) = data_range(&[3.0, 1.0, 4.0, 1.5]);
        assert!((lo - 1.0).abs() < 1e-9);
        assert!((hi - 4.0).abs() < 1e-9);
    }

    // No finite values: falls back to the default (0, 1) range.
    #[cfg(any(feature = "plot", feature = "plot-svg"))]
    #[test]
    fn test_data_range_all_nan() {
        let (lo, hi) = data_range(&[f64::NAN]);
        assert_eq!((lo, hi), (0.0, 1.0));
    }

    // Constant data must still yield a non-empty range.
    #[cfg(any(feature = "plot", feature = "plot-svg"))]
    #[test]
    fn test_data_range_constant() {
        let (lo, hi) = data_range(&[5.0, 5.0, 5.0]);
        assert!(hi > lo);
    }
}
565}