1use d3rs::color::D3Color;
4use std::sync::Arc;
5
6#[derive(Clone, Default)]
8pub enum ColorScale {
9 #[default]
11 Viridis,
12 Plasma,
14 Inferno,
16 Magma,
18 Heat,
20 Coolwarm,
22 Greys,
24 Custom(Arc<dyn Fn(f64) -> D3Color + Send + Sync>),
26}
27
28impl std::fmt::Debug for ColorScale {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 ColorScale::Viridis => write!(f, "ColorScale::Viridis"),
32 ColorScale::Plasma => write!(f, "ColorScale::Plasma"),
33 ColorScale::Inferno => write!(f, "ColorScale::Inferno"),
34 ColorScale::Magma => write!(f, "ColorScale::Magma"),
35 ColorScale::Heat => write!(f, "ColorScale::Heat"),
36 ColorScale::Coolwarm => write!(f, "ColorScale::Coolwarm"),
37 ColorScale::Greys => write!(f, "ColorScale::Greys"),
38 ColorScale::Custom(_) => write!(f, "ColorScale::Custom(...)"),
39 }
40 }
41}
42
43
44impl ColorScale {
45 pub fn custom<F>(f: F) -> Self
49 where
50 F: Fn(f64) -> D3Color + Send + Sync + 'static,
51 {
52 ColorScale::Custom(Arc::new(f))
53 }
54
55 pub fn to_fn(&self) -> impl Fn(f64) -> D3Color + Send + Sync + Clone + 'static {
57 let scale = self.clone();
58 move |t: f64| scale.map(t)
59 }
60
61 pub fn map(&self, t: f64) -> D3Color {
63 let t = t.clamp(0.0, 1.0);
64
65 match self {
66 ColorScale::Viridis => viridis(t),
67 ColorScale::Plasma => plasma(t),
68 ColorScale::Inferno => inferno(t),
69 ColorScale::Magma => magma(t),
70 ColorScale::Heat => heat(t),
71 ColorScale::Coolwarm => coolwarm(t),
72 ColorScale::Greys => greys(t),
73 ColorScale::Custom(f) => f(t),
74 }
75 }
76}
77
78fn interpolate_palette(t: f64, colors: &[D3Color]) -> D3Color {
80 let idx = (t * (colors.len() - 1) as f64) as usize;
81 let idx = idx.min(colors.len() - 2);
82 let local_t = (t * (colors.len() - 1) as f64) - idx as f64;
83 colors[idx].interpolate(&colors[idx + 1], local_t as f32)
84}
85
86fn viridis(t: f64) -> D3Color {
88 let colors = [
89 D3Color::from_hex(0x440154),
90 D3Color::from_hex(0x482878),
91 D3Color::from_hex(0x3e4a89),
92 D3Color::from_hex(0x31688e),
93 D3Color::from_hex(0x26838f),
94 D3Color::from_hex(0x1f9e89),
95 D3Color::from_hex(0x35b779),
96 D3Color::from_hex(0x6ece58),
97 D3Color::from_hex(0xb5de2b),
98 D3Color::from_hex(0xfde725),
99 ];
100 interpolate_palette(t, &colors)
101}
102
103fn plasma(t: f64) -> D3Color {
105 let colors = [
106 D3Color::from_hex(0x0d0887),
107 D3Color::from_hex(0x46039f),
108 D3Color::from_hex(0x7201a8),
109 D3Color::from_hex(0x9c179e),
110 D3Color::from_hex(0xbd3786),
111 D3Color::from_hex(0xd8576b),
112 D3Color::from_hex(0xed7953),
113 D3Color::from_hex(0xfb9f3a),
114 D3Color::from_hex(0xfdca26),
115 D3Color::from_hex(0xf0f921),
116 ];
117 interpolate_palette(t, &colors)
118}
119
120fn inferno(t: f64) -> D3Color {
122 let colors = [
123 D3Color::from_hex(0x000004),
124 D3Color::from_hex(0x1b0c41),
125 D3Color::from_hex(0x4a0c6b),
126 D3Color::from_hex(0x781c6d),
127 D3Color::from_hex(0xa52c60),
128 D3Color::from_hex(0xcf4446),
129 D3Color::from_hex(0xed6925),
130 D3Color::from_hex(0xfb9b06),
131 D3Color::from_hex(0xf7d13d),
132 D3Color::from_hex(0xfcffa4),
133 ];
134 interpolate_palette(t, &colors)
135}
136
137fn magma(t: f64) -> D3Color {
139 let colors = [
140 D3Color::from_hex(0x000004),
141 D3Color::from_hex(0x180f3d),
142 D3Color::from_hex(0x440f76),
143 D3Color::from_hex(0x721f81),
144 D3Color::from_hex(0x9e2f7f),
145 D3Color::from_hex(0xcd4071),
146 D3Color::from_hex(0xf1605d),
147 D3Color::from_hex(0xfd9668),
148 D3Color::from_hex(0xfeca8d),
149 D3Color::from_hex(0xfcfdbf),
150 ];
151 interpolate_palette(t, &colors)
152}
153
154fn heat(t: f64) -> D3Color {
156 if t < 0.5 {
157 let local_t = t * 2.0;
158 D3Color::from_hex(0x0571b0).interpolate(&D3Color::from_hex(0xf7f7f7), local_t as f32)
159 } else {
160 let local_t = (t - 0.5) * 2.0;
161 D3Color::from_hex(0xf7f7f7).interpolate(&D3Color::from_hex(0xca0020), local_t as f32)
162 }
163}
164
165fn coolwarm(t: f64) -> D3Color {
167 let colors = [
168 D3Color::from_hex(0x3b4cc0),
169 D3Color::from_hex(0x6788ee),
170 D3Color::from_hex(0x9abbff),
171 D3Color::from_hex(0xc9d7f0),
172 D3Color::from_hex(0xedd1c2),
173 D3Color::from_hex(0xf7a789),
174 D3Color::from_hex(0xe36a53),
175 D3Color::from_hex(0xb40426),
176 ];
177 interpolate_palette(t, &colors)
178}
179
180fn greys(t: f64) -> D3Color {
182 D3Color::from_hex(0xffffff).interpolate(&D3Color::from_hex(0x000000), t as f32)
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in (non-custom) scale, for properties that should hold across all of them.
    fn builtin_scales() -> [ColorScale; 7] {
        [
            ColorScale::Viridis,
            ColorScale::Plasma,
            ColorScale::Inferno,
            ColorScale::Magma,
            ColorScale::Heat,
            ColorScale::Coolwarm,
            ColorScale::Greys,
        ]
    }

    /// Asserts that two colors have identical red, green and blue channels.
    fn assert_same_rgb(a: &D3Color, b: &D3Color) {
        assert_eq!(a.r(), b.r());
        assert_eq!(a.g(), b.g());
        assert_eq!(a.b(), b.b());
    }

    #[test]
    fn test_viridis_endpoints() {
        let start = ColorScale::Viridis.map(0.0);
        let end = ColorScale::Viridis.map(1.0);
        assert!(start.r() < 0.3);
        assert!(end.g() > 0.8);
    }

    #[test]
    fn test_plasma_endpoints() {
        let start = ColorScale::Plasma.map(0.0);
        let end = ColorScale::Plasma.map(1.0);
        assert!(start.b() > 0.3);
        assert!(end.r() > 0.8);
    }

    #[test]
    fn test_heat_diverging() {
        let low = ColorScale::Heat.map(0.0);
        let mid = ColorScale::Heat.map(0.5);
        let high = ColorScale::Heat.map(1.0);
        assert!(low.b() > low.r());
        assert!(mid.r() > 0.9 && mid.g() > 0.9 && mid.b() > 0.9);
        assert!(high.r() > high.b());
    }

    #[test]
    fn test_greys_endpoints() {
        let start = ColorScale::Greys.map(0.0);
        let end = ColorScale::Greys.map(1.0);
        assert!(start.r() > 0.99 && start.g() > 0.99 && start.b() > 0.99);
        assert!(end.r() < 0.01 && end.g() < 0.01 && end.b() < 0.01);
    }

    #[test]
    fn test_custom_scale() {
        let scale = ColorScale::custom(|t| {
            D3Color::from_hex(0xff0000).interpolate(&D3Color::from_hex(0x00ff00), t as f32)
        });
        let mid = scale.map(0.5);
        assert!(mid.r() > 0.4 && mid.g() > 0.4);
    }

    #[test]
    fn test_clamp_out_of_bounds() {
        // Previously only Viridis' red channel was checked; verify every channel of
        // every built-in scale.
        for scale in builtin_scales() {
            assert_same_rgb(&scale.map(-0.5), &scale.map(0.0));
            assert_same_rgb(&scale.map(1.5), &scale.map(1.0));
        }
    }

    #[test]
    fn test_custom_receives_clamped_t() {
        let scale = ColorScale::custom(|t| {
            assert!((0.0..=1.0).contains(&t), "custom scale saw unclamped t = {t}");
            D3Color::from_hex(0x000000)
        });
        let _ = scale.map(-3.0);
        let _ = scale.map(42.0);
    }

    #[test]
    fn test_to_fn() {
        let f = ColorScale::Viridis.to_fn();
        let direct = ColorScale::Viridis.map(0.5);
        assert_same_rgb(&direct, &f(0.5));
    }

    #[test]
    fn test_to_fn_custom_and_clone() {
        let scale = ColorScale::custom(|t| {
            D3Color::from_hex(0x000000).interpolate(&D3Color::from_hex(0xffffff), t as f32)
        });
        let f = scale.to_fn();
        let g = f.clone();
        assert_same_rgb(&f(0.25), &scale.map(0.25));
        assert_same_rgb(&g(0.25), &scale.map(0.25));
    }

    #[test]
    fn test_default() {
        let scale = ColorScale::default();
        assert!(matches!(scale, ColorScale::Viridis));
    }

    #[test]
    fn test_debug() {
        assert_eq!(format!("{:?}", ColorScale::Viridis), "ColorScale::Viridis");
        assert_eq!(format!("{:?}", ColorScale::Heat), "ColorScale::Heat");
        let custom = ColorScale::custom(|_| D3Color::from_hex(0x000000));
        assert_eq!(format!("{:?}", custom), "ColorScale::Custom(...)");
    }
}