gpui_px/
color_scale.rs

1//! Color scales for heatmaps, contours, and isolines.
2
3use d3rs::color::D3Color;
4use std::sync::Arc;
5
6/// Color scale for 2D visualizations (heatmaps, contours).
7#[derive(Clone, Default)]
8pub enum ColorScale {
9    /// Viridis - perceptually uniform, colorblind-friendly (purple → yellow).
10    #[default]
11    Viridis,
12    /// Plasma - perceptually uniform (purple → orange → yellow).
13    Plasma,
14    /// Inferno - perceptually uniform (black → purple → orange → yellow).
15    Inferno,
16    /// Magma - perceptually uniform (black → purple → orange → white).
17    Magma,
18    /// Heat - diverging (blue → white → red).
19    Heat,
20    /// Coolwarm - diverging (cool blue → neutral → warm red).
21    Coolwarm,
22    /// Greys - sequential grayscale (white → black).
23    Greys,
24    /// Custom color scale function.
25    Custom(Arc<dyn Fn(f64) -> D3Color + Send + Sync>),
26}
27
28impl std::fmt::Debug for ColorScale {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        match self {
31            ColorScale::Viridis => write!(f, "ColorScale::Viridis"),
32            ColorScale::Plasma => write!(f, "ColorScale::Plasma"),
33            ColorScale::Inferno => write!(f, "ColorScale::Inferno"),
34            ColorScale::Magma => write!(f, "ColorScale::Magma"),
35            ColorScale::Heat => write!(f, "ColorScale::Heat"),
36            ColorScale::Coolwarm => write!(f, "ColorScale::Coolwarm"),
37            ColorScale::Greys => write!(f, "ColorScale::Greys"),
38            ColorScale::Custom(_) => write!(f, "ColorScale::Custom(...)"),
39        }
40    }
41}
42
43
44impl ColorScale {
45    /// Create a custom color scale from a function.
46    ///
47    /// The function should map values in [0, 1] to colors.
48    pub fn custom<F>(f: F) -> Self
49    where
50        F: Fn(f64) -> D3Color + Send + Sync + 'static,
51    {
52        ColorScale::Custom(Arc::new(f))
53    }
54
55    /// Convert to a function that maps [0, 1] → D3Color.
56    pub fn to_fn(&self) -> impl Fn(f64) -> D3Color + Send + Sync + Clone + 'static {
57        let scale = self.clone();
58        move |t: f64| scale.map(t)
59    }
60
61    /// Map a value in [0, 1] to a color.
62    pub fn map(&self, t: f64) -> D3Color {
63        let t = t.clamp(0.0, 1.0);
64
65        match self {
66            ColorScale::Viridis => viridis(t),
67            ColorScale::Plasma => plasma(t),
68            ColorScale::Inferno => inferno(t),
69            ColorScale::Magma => magma(t),
70            ColorScale::Heat => heat(t),
71            ColorScale::Coolwarm => coolwarm(t),
72            ColorScale::Greys => greys(t),
73            ColorScale::Custom(f) => f(t),
74        }
75    }
76}
77
78// Helper function to interpolate between colors in a palette
79fn interpolate_palette(t: f64, colors: &[D3Color]) -> D3Color {
80    let idx = (t * (colors.len() - 1) as f64) as usize;
81    let idx = idx.min(colors.len() - 2);
82    let local_t = (t * (colors.len() - 1) as f64) - idx as f64;
83    colors[idx].interpolate(&colors[idx + 1], local_t as f32)
84}
85
86/// Viridis colormap (matplotlib/d3)
87fn viridis(t: f64) -> D3Color {
88    let colors = [
89        D3Color::from_hex(0x440154),
90        D3Color::from_hex(0x482878),
91        D3Color::from_hex(0x3e4a89),
92        D3Color::from_hex(0x31688e),
93        D3Color::from_hex(0x26838f),
94        D3Color::from_hex(0x1f9e89),
95        D3Color::from_hex(0x35b779),
96        D3Color::from_hex(0x6ece58),
97        D3Color::from_hex(0xb5de2b),
98        D3Color::from_hex(0xfde725),
99    ];
100    interpolate_palette(t, &colors)
101}
102
103/// Plasma colormap (matplotlib/d3)
104fn plasma(t: f64) -> D3Color {
105    let colors = [
106        D3Color::from_hex(0x0d0887),
107        D3Color::from_hex(0x46039f),
108        D3Color::from_hex(0x7201a8),
109        D3Color::from_hex(0x9c179e),
110        D3Color::from_hex(0xbd3786),
111        D3Color::from_hex(0xd8576b),
112        D3Color::from_hex(0xed7953),
113        D3Color::from_hex(0xfb9f3a),
114        D3Color::from_hex(0xfdca26),
115        D3Color::from_hex(0xf0f921),
116    ];
117    interpolate_palette(t, &colors)
118}
119
120/// Inferno colormap (matplotlib/d3)
121fn inferno(t: f64) -> D3Color {
122    let colors = [
123        D3Color::from_hex(0x000004),
124        D3Color::from_hex(0x1b0c41),
125        D3Color::from_hex(0x4a0c6b),
126        D3Color::from_hex(0x781c6d),
127        D3Color::from_hex(0xa52c60),
128        D3Color::from_hex(0xcf4446),
129        D3Color::from_hex(0xed6925),
130        D3Color::from_hex(0xfb9b06),
131        D3Color::from_hex(0xf7d13d),
132        D3Color::from_hex(0xfcffa4),
133    ];
134    interpolate_palette(t, &colors)
135}
136
137/// Magma colormap (matplotlib/d3)
138fn magma(t: f64) -> D3Color {
139    let colors = [
140        D3Color::from_hex(0x000004),
141        D3Color::from_hex(0x180f3d),
142        D3Color::from_hex(0x440f76),
143        D3Color::from_hex(0x721f81),
144        D3Color::from_hex(0x9e2f7f),
145        D3Color::from_hex(0xcd4071),
146        D3Color::from_hex(0xf1605d),
147        D3Color::from_hex(0xfd9668),
148        D3Color::from_hex(0xfeca8d),
149        D3Color::from_hex(0xfcfdbf),
150    ];
151    interpolate_palette(t, &colors)
152}
153
154/// Heat colormap (blue → white → red)
155fn heat(t: f64) -> D3Color {
156    if t < 0.5 {
157        let local_t = t * 2.0;
158        D3Color::from_hex(0x0571b0).interpolate(&D3Color::from_hex(0xf7f7f7), local_t as f32)
159    } else {
160        let local_t = (t - 0.5) * 2.0;
161        D3Color::from_hex(0xf7f7f7).interpolate(&D3Color::from_hex(0xca0020), local_t as f32)
162    }
163}
164
165/// Coolwarm colormap (diverging)
166fn coolwarm(t: f64) -> D3Color {
167    let colors = [
168        D3Color::from_hex(0x3b4cc0),
169        D3Color::from_hex(0x6788ee),
170        D3Color::from_hex(0x9abbff),
171        D3Color::from_hex(0xc9d7f0),
172        D3Color::from_hex(0xedd1c2),
173        D3Color::from_hex(0xf7a789),
174        D3Color::from_hex(0xe36a53),
175        D3Color::from_hex(0xb40426),
176    ];
177    interpolate_palette(t, &colors)
178}
179
180/// Greys colormap (white → black)
181fn greys(t: f64) -> D3Color {
182    D3Color::from_hex(0xffffff).interpolate(&D3Color::from_hex(0x000000), t as f32)
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Viridis should start at a dark purple and finish at yellow.
    #[test]
    fn test_viridis_endpoints() {
        let start = ColorScale::Viridis.map(0.0);
        let end = ColorScale::Viridis.map(1.0);
        assert!(start.r() < 0.3);
        assert!(end.g() > 0.8);
    }

    /// Plasma should start at a dark blue and finish at a light yellow.
    #[test]
    fn test_plasma_endpoints() {
        let start = ColorScale::Plasma.map(0.0);
        let end = ColorScale::Plasma.map(1.0);
        assert!(start.b() > 0.3);
        assert!(end.r() > 0.8);
    }

    /// Heat is diverging: blue at the low end, near-white in the middle, red at the top.
    #[test]
    fn test_heat_diverging() {
        let low = ColorScale::Heat.map(0.0);
        let mid = ColorScale::Heat.map(0.5);
        let high = ColorScale::Heat.map(1.0);
        assert!(low.b() > low.r());
        assert!(mid.r() > 0.9 && mid.g() > 0.9 && mid.b() > 0.9);
        assert!(high.r() > high.b());
    }

    /// Greys runs from white at 0 to black at 1.
    #[test]
    fn test_greys_endpoints() {
        let start = ColorScale::Greys.map(0.0);
        let end = ColorScale::Greys.map(1.0);
        assert!(start.r() > 0.99 && start.g() > 0.99 && start.b() > 0.99);
        assert!(end.r() < 0.01 && end.g() < 0.01 && end.b() < 0.01);
    }

    /// A custom red → green scale should be yellow-ish (red + green) at its midpoint.
    #[test]
    fn test_custom_scale() {
        let scale = ColorScale::custom(|t| {
            D3Color::from_hex(0xff0000).interpolate(&D3Color::from_hex(0x00ff00), t as f32)
        });
        let mid = scale.map(0.5);
        assert!(mid.r() > 0.4 && mid.g() > 0.4);
    }

    /// Inputs outside [0, 1] must produce exactly the endpoint colors.
    ///
    /// All three channels are compared so that a clamping regression that
    /// happens to leave the red channel untouched is still caught.
    #[test]
    fn test_clamp_out_of_bounds() {
        let below = ColorScale::Viridis.map(-0.5);
        let above = ColorScale::Viridis.map(1.5);
        let at_zero = ColorScale::Viridis.map(0.0);
        let at_one = ColorScale::Viridis.map(1.0);

        assert_eq!(below.r(), at_zero.r());
        assert_eq!(below.g(), at_zero.g());
        assert_eq!(below.b(), at_zero.b());

        assert_eq!(above.r(), at_one.r());
        assert_eq!(above.g(), at_one.g());
        assert_eq!(above.b(), at_one.b());
    }

    /// The closure returned by `to_fn` must agree with calling `map` directly.
    #[test]
    fn test_to_fn() {
        let f = ColorScale::Viridis.to_fn();
        let direct = ColorScale::Viridis.map(0.5);
        let via_fn = f(0.5);
        assert_eq!(direct.r(), via_fn.r());
        assert_eq!(direct.g(), via_fn.g());
        assert_eq!(direct.b(), via_fn.b());
    }

    /// The default scale is Viridis.
    #[test]
    fn test_default() {
        let scale = ColorScale::default();
        assert!(matches!(scale, ColorScale::Viridis));
    }

    /// `Debug` prints the qualified variant name and a placeholder for custom closures.
    #[test]
    fn test_debug() {
        assert_eq!(format!("{:?}", ColorScale::Viridis), "ColorScale::Viridis");
        assert_eq!(format!("{:?}", ColorScale::Heat), "ColorScale::Heat");
        let custom = ColorScale::custom(|_| D3Color::from_hex(0x000000));
        assert_eq!(format!("{:?}", custom), "ColorScale::Custom(...)");
    }
}