gpui_px/
color_scale.rs

1//! Color scales for heatmaps, contours, and isolines.
2
3use d3rs::color::D3Color;
4use std::sync::Arc;
5
6/// Color scale for 2D visualizations (heatmaps, contours).
7#[derive(Clone, Default)]
8pub enum ColorScale {
9    /// Viridis - perceptually uniform, colorblind-friendly (purple → yellow).
10    #[default]
11    Viridis,
12    /// Plasma - perceptually uniform (purple → orange → yellow).
13    Plasma,
14    /// Inferno - perceptually uniform (black → purple → orange → yellow).
15    Inferno,
16    /// Magma - perceptually uniform (black → purple → orange → white).
17    Magma,
18    /// Heat - diverging (blue → white → red).
19    Heat,
20    /// Coolwarm - diverging (cool blue → neutral → warm red).
21    Coolwarm,
22    /// Greys - sequential grayscale (white → black).
23    Greys,
24    /// Custom color scale function.
25    Custom(Arc<dyn Fn(f64) -> D3Color + Send + Sync>),
26}
27
28impl std::fmt::Debug for ColorScale {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        match self {
31            ColorScale::Viridis => write!(f, "ColorScale::Viridis"),
32            ColorScale::Plasma => write!(f, "ColorScale::Plasma"),
33            ColorScale::Inferno => write!(f, "ColorScale::Inferno"),
34            ColorScale::Magma => write!(f, "ColorScale::Magma"),
35            ColorScale::Heat => write!(f, "ColorScale::Heat"),
36            ColorScale::Coolwarm => write!(f, "ColorScale::Coolwarm"),
37            ColorScale::Greys => write!(f, "ColorScale::Greys"),
38            ColorScale::Custom(_) => write!(f, "ColorScale::Custom(...)"),
39        }
40    }
41}
42
43impl ColorScale {
44    /// Create a custom color scale from a function.
45    ///
46    /// The function should map values in [0, 1] to colors.
47    pub fn custom<F>(f: F) -> Self
48    where
49        F: Fn(f64) -> D3Color + Send + Sync + 'static,
50    {
51        ColorScale::Custom(Arc::new(f))
52    }
53
54    /// Convert to a function that maps [0, 1] → D3Color.
55    pub fn to_fn(&self) -> impl Fn(f64) -> D3Color + Send + Sync + Clone + 'static {
56        let scale = self.clone();
57        move |t: f64| scale.map(t)
58    }
59
60    /// Map a value in [0, 1] to a color.
61    pub fn map(&self, t: f64) -> D3Color {
62        let t = t.clamp(0.0, 1.0);
63
64        match self {
65            ColorScale::Viridis => viridis(t),
66            ColorScale::Plasma => plasma(t),
67            ColorScale::Inferno => inferno(t),
68            ColorScale::Magma => magma(t),
69            ColorScale::Heat => heat(t),
70            ColorScale::Coolwarm => coolwarm(t),
71            ColorScale::Greys => greys(t),
72            ColorScale::Custom(f) => f(t),
73        }
74    }
75}
76
77// Helper function to interpolate between colors in a palette
78fn interpolate_palette(t: f64, colors: &[D3Color]) -> D3Color {
79    let idx = (t * (colors.len() - 1) as f64) as usize;
80    let idx = idx.min(colors.len() - 2);
81    let local_t = (t * (colors.len() - 1) as f64) - idx as f64;
82    colors[idx].interpolate(&colors[idx + 1], local_t as f32)
83}
84
85/// Viridis colormap (matplotlib/d3)
86fn viridis(t: f64) -> D3Color {
87    let colors = [
88        D3Color::from_hex(0x440154),
89        D3Color::from_hex(0x482878),
90        D3Color::from_hex(0x3e4a89),
91        D3Color::from_hex(0x31688e),
92        D3Color::from_hex(0x26838f),
93        D3Color::from_hex(0x1f9e89),
94        D3Color::from_hex(0x35b779),
95        D3Color::from_hex(0x6ece58),
96        D3Color::from_hex(0xb5de2b),
97        D3Color::from_hex(0xfde725),
98    ];
99    interpolate_palette(t, &colors)
100}
101
102/// Plasma colormap (matplotlib/d3)
103fn plasma(t: f64) -> D3Color {
104    let colors = [
105        D3Color::from_hex(0x0d0887),
106        D3Color::from_hex(0x46039f),
107        D3Color::from_hex(0x7201a8),
108        D3Color::from_hex(0x9c179e),
109        D3Color::from_hex(0xbd3786),
110        D3Color::from_hex(0xd8576b),
111        D3Color::from_hex(0xed7953),
112        D3Color::from_hex(0xfb9f3a),
113        D3Color::from_hex(0xfdca26),
114        D3Color::from_hex(0xf0f921),
115    ];
116    interpolate_palette(t, &colors)
117}
118
119/// Inferno colormap (matplotlib/d3)
120fn inferno(t: f64) -> D3Color {
121    let colors = [
122        D3Color::from_hex(0x000004),
123        D3Color::from_hex(0x1b0c41),
124        D3Color::from_hex(0x4a0c6b),
125        D3Color::from_hex(0x781c6d),
126        D3Color::from_hex(0xa52c60),
127        D3Color::from_hex(0xcf4446),
128        D3Color::from_hex(0xed6925),
129        D3Color::from_hex(0xfb9b06),
130        D3Color::from_hex(0xf7d13d),
131        D3Color::from_hex(0xfcffa4),
132    ];
133    interpolate_palette(t, &colors)
134}
135
136/// Magma colormap (matplotlib/d3)
137fn magma(t: f64) -> D3Color {
138    let colors = [
139        D3Color::from_hex(0x000004),
140        D3Color::from_hex(0x180f3d),
141        D3Color::from_hex(0x440f76),
142        D3Color::from_hex(0x721f81),
143        D3Color::from_hex(0x9e2f7f),
144        D3Color::from_hex(0xcd4071),
145        D3Color::from_hex(0xf1605d),
146        D3Color::from_hex(0xfd9668),
147        D3Color::from_hex(0xfeca8d),
148        D3Color::from_hex(0xfcfdbf),
149    ];
150    interpolate_palette(t, &colors)
151}
152
153/// Heat colormap (blue → white → red)
154fn heat(t: f64) -> D3Color {
155    if t < 0.5 {
156        let local_t = t * 2.0;
157        D3Color::from_hex(0x0571b0).interpolate(&D3Color::from_hex(0xf7f7f7), local_t as f32)
158    } else {
159        let local_t = (t - 0.5) * 2.0;
160        D3Color::from_hex(0xf7f7f7).interpolate(&D3Color::from_hex(0xca0020), local_t as f32)
161    }
162}
163
164/// Coolwarm colormap (diverging)
165fn coolwarm(t: f64) -> D3Color {
166    let colors = [
167        D3Color::from_hex(0x3b4cc0),
168        D3Color::from_hex(0x6788ee),
169        D3Color::from_hex(0x9abbff),
170        D3Color::from_hex(0xc9d7f0),
171        D3Color::from_hex(0xedd1c2),
172        D3Color::from_hex(0xf7a789),
173        D3Color::from_hex(0xe36a53),
174        D3Color::from_hex(0xb40426),
175    ];
176    interpolate_palette(t, &colors)
177}
178
179/// Greys colormap (white → black)
180fn greys(t: f64) -> D3Color {
181    D3Color::from_hex(0xffffff).interpolate(&D3Color::from_hex(0x000000), t as f32)
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Every non-custom scale, for properties that must hold across all of them.
    fn builtin_scales() -> [ColorScale; 7] {
        [
            ColorScale::Viridis,
            ColorScale::Plasma,
            ColorScale::Inferno,
            ColorScale::Magma,
            ColorScale::Heat,
            ColorScale::Coolwarm,
            ColorScale::Greys,
        ]
    }

    #[test]
    fn test_viridis_endpoints() {
        let start = ColorScale::Viridis.map(0.0);
        let end = ColorScale::Viridis.map(1.0);
        // Viridis starts dark purple, ends yellow
        assert!(start.r() < 0.3);
        assert!(end.g() > 0.8);
    }

    #[test]
    fn test_plasma_endpoints() {
        let start = ColorScale::Plasma.map(0.0);
        let end = ColorScale::Plasma.map(1.0);
        // Plasma starts dark blue, ends light yellow
        assert!(start.b() > 0.3);
        assert!(end.r() > 0.8);
    }

    #[test]
    fn test_inferno_endpoints() {
        let start = ColorScale::Inferno.map(0.0);
        let end = ColorScale::Inferno.map(1.0);
        // Inferno starts near-black (0x000004) and ends pale yellow (0xfcffa4).
        assert!(start.r() < 0.05 && start.g() < 0.05);
        assert!(end.r() > 0.9 && end.g() > 0.9);
    }

    #[test]
    fn test_magma_endpoints() {
        let start = ColorScale::Magma.map(0.0);
        let end = ColorScale::Magma.map(1.0);
        // Magma starts near-black (0x000004) and ends pale cream (0xfcfdbf).
        assert!(start.r() < 0.05 && start.g() < 0.05);
        assert!(end.r() > 0.9 && end.g() > 0.9 && end.b() > 0.7);
    }

    #[test]
    fn test_heat_diverging() {
        let low = ColorScale::Heat.map(0.0);
        let mid = ColorScale::Heat.map(0.5);
        let high = ColorScale::Heat.map(1.0);
        // Low is blue, mid is white-ish, high is red
        assert!(low.b() > low.r());
        assert!(mid.r() > 0.9 && mid.g() > 0.9 && mid.b() > 0.9);
        assert!(high.r() > high.b());
    }

    #[test]
    fn test_coolwarm_diverging() {
        let low = ColorScale::Coolwarm.map(0.0);
        let high = ColorScale::Coolwarm.map(1.0);
        // Low end is blue (0x3b4cc0), high end is red (0xb40426).
        assert!(low.b() > low.r());
        assert!(high.r() > high.b());
    }

    #[test]
    fn test_greys_endpoints() {
        let start = ColorScale::Greys.map(0.0);
        let end = ColorScale::Greys.map(1.0);
        // Start is white, end is black
        assert!(start.r() > 0.99 && start.g() > 0.99 && start.b() > 0.99);
        assert!(end.r() < 0.01 && end.g() < 0.01 && end.b() < 0.01);
    }

    #[test]
    fn test_custom_scale() {
        let scale = ColorScale::custom(|t| {
            D3Color::from_hex(0xff0000).interpolate(&D3Color::from_hex(0x00ff00), t as f32)
        });
        let mid = scale.map(0.5);
        // Mid should be yellow-ish (red + green)
        assert!(mid.r() > 0.4 && mid.g() > 0.4);
    }

    #[test]
    fn test_clamp_out_of_bounds() {
        // Values outside [0, 1] should be clamped
        let below = ColorScale::Viridis.map(-0.5);
        let above = ColorScale::Viridis.map(1.5);
        let at_zero = ColorScale::Viridis.map(0.0);
        let at_one = ColorScale::Viridis.map(1.0);

        // Should be clamped to endpoints
        assert_eq!(below.r(), at_zero.r());
        assert_eq!(above.r(), at_one.r());
    }

    #[test]
    fn test_builtin_scales_clamp_out_of_bounds() {
        // `map` clamps before dispatching, so out-of-range inputs must reproduce the
        // endpoint colors exactly for every built-in scale.
        for scale in &builtin_scales() {
            let (below, at_zero) = (scale.map(-2.0), scale.map(0.0));
            let (above, at_one) = (scale.map(2.0), scale.map(1.0));
            assert_eq!(
                (below.r(), below.g(), below.b()),
                (at_zero.r(), at_zero.g(), at_zero.b()),
                "{scale:?} below range"
            );
            assert_eq!(
                (above.r(), above.g(), above.b()),
                (at_one.r(), at_one.g(), at_one.b()),
                "{scale:?} above range"
            );
        }
    }

    #[test]
    fn test_to_fn() {
        let f = ColorScale::Viridis.to_fn();
        let direct = ColorScale::Viridis.map(0.5);
        let via_fn = f(0.5);
        assert_eq!(direct.r(), via_fn.r());
        assert_eq!(direct.g(), via_fn.g());
        assert_eq!(direct.b(), via_fn.b());
    }

    #[test]
    fn test_default() {
        let scale = ColorScale::default();
        assert!(matches!(scale, ColorScale::Viridis));
    }

    #[test]
    fn test_debug() {
        assert_eq!(format!("{:?}", ColorScale::Viridis), "ColorScale::Viridis");
        assert_eq!(format!("{:?}", ColorScale::Heat), "ColorScale::Heat");
        let custom = ColorScale::custom(|_| D3Color::from_hex(0x000000));
        assert_eq!(format!("{:?}", custom), "ColorScale::Custom(...)");
    }
}