1use d3rs::color::D3Color;
4use std::sync::Arc;
5
6#[derive(Clone, Default)]
8pub enum ColorScale {
9 #[default]
11 Viridis,
12 Plasma,
14 Inferno,
16 Magma,
18 Heat,
20 Coolwarm,
22 Greys,
24 Custom(Arc<dyn Fn(f64) -> D3Color + Send + Sync>),
26}
27
28impl std::fmt::Debug for ColorScale {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 match self {
31 ColorScale::Viridis => write!(f, "ColorScale::Viridis"),
32 ColorScale::Plasma => write!(f, "ColorScale::Plasma"),
33 ColorScale::Inferno => write!(f, "ColorScale::Inferno"),
34 ColorScale::Magma => write!(f, "ColorScale::Magma"),
35 ColorScale::Heat => write!(f, "ColorScale::Heat"),
36 ColorScale::Coolwarm => write!(f, "ColorScale::Coolwarm"),
37 ColorScale::Greys => write!(f, "ColorScale::Greys"),
38 ColorScale::Custom(_) => write!(f, "ColorScale::Custom(...)"),
39 }
40 }
41}
42
43impl ColorScale {
44 pub fn custom<F>(f: F) -> Self
48 where
49 F: Fn(f64) -> D3Color + Send + Sync + 'static,
50 {
51 ColorScale::Custom(Arc::new(f))
52 }
53
54 pub fn to_fn(&self) -> impl Fn(f64) -> D3Color + Send + Sync + Clone + 'static {
56 let scale = self.clone();
57 move |t: f64| scale.map(t)
58 }
59
60 pub fn map(&self, t: f64) -> D3Color {
62 let t = t.clamp(0.0, 1.0);
63
64 match self {
65 ColorScale::Viridis => viridis(t),
66 ColorScale::Plasma => plasma(t),
67 ColorScale::Inferno => inferno(t),
68 ColorScale::Magma => magma(t),
69 ColorScale::Heat => heat(t),
70 ColorScale::Coolwarm => coolwarm(t),
71 ColorScale::Greys => greys(t),
72 ColorScale::Custom(f) => f(t),
73 }
74 }
75}
76
77fn interpolate_palette(t: f64, colors: &[D3Color]) -> D3Color {
79 let idx = (t * (colors.len() - 1) as f64) as usize;
80 let idx = idx.min(colors.len() - 2);
81 let local_t = (t * (colors.len() - 1) as f64) - idx as f64;
82 colors[idx].interpolate(&colors[idx + 1], local_t as f32)
83}
84
85fn viridis(t: f64) -> D3Color {
87 let colors = [
88 D3Color::from_hex(0x440154),
89 D3Color::from_hex(0x482878),
90 D3Color::from_hex(0x3e4a89),
91 D3Color::from_hex(0x31688e),
92 D3Color::from_hex(0x26838f),
93 D3Color::from_hex(0x1f9e89),
94 D3Color::from_hex(0x35b779),
95 D3Color::from_hex(0x6ece58),
96 D3Color::from_hex(0xb5de2b),
97 D3Color::from_hex(0xfde725),
98 ];
99 interpolate_palette(t, &colors)
100}
101
102fn plasma(t: f64) -> D3Color {
104 let colors = [
105 D3Color::from_hex(0x0d0887),
106 D3Color::from_hex(0x46039f),
107 D3Color::from_hex(0x7201a8),
108 D3Color::from_hex(0x9c179e),
109 D3Color::from_hex(0xbd3786),
110 D3Color::from_hex(0xd8576b),
111 D3Color::from_hex(0xed7953),
112 D3Color::from_hex(0xfb9f3a),
113 D3Color::from_hex(0xfdca26),
114 D3Color::from_hex(0xf0f921),
115 ];
116 interpolate_palette(t, &colors)
117}
118
119fn inferno(t: f64) -> D3Color {
121 let colors = [
122 D3Color::from_hex(0x000004),
123 D3Color::from_hex(0x1b0c41),
124 D3Color::from_hex(0x4a0c6b),
125 D3Color::from_hex(0x781c6d),
126 D3Color::from_hex(0xa52c60),
127 D3Color::from_hex(0xcf4446),
128 D3Color::from_hex(0xed6925),
129 D3Color::from_hex(0xfb9b06),
130 D3Color::from_hex(0xf7d13d),
131 D3Color::from_hex(0xfcffa4),
132 ];
133 interpolate_palette(t, &colors)
134}
135
136fn magma(t: f64) -> D3Color {
138 let colors = [
139 D3Color::from_hex(0x000004),
140 D3Color::from_hex(0x180f3d),
141 D3Color::from_hex(0x440f76),
142 D3Color::from_hex(0x721f81),
143 D3Color::from_hex(0x9e2f7f),
144 D3Color::from_hex(0xcd4071),
145 D3Color::from_hex(0xf1605d),
146 D3Color::from_hex(0xfd9668),
147 D3Color::from_hex(0xfeca8d),
148 D3Color::from_hex(0xfcfdbf),
149 ];
150 interpolate_palette(t, &colors)
151}
152
153fn heat(t: f64) -> D3Color {
155 if t < 0.5 {
156 let local_t = t * 2.0;
157 D3Color::from_hex(0x0571b0).interpolate(&D3Color::from_hex(0xf7f7f7), local_t as f32)
158 } else {
159 let local_t = (t - 0.5) * 2.0;
160 D3Color::from_hex(0xf7f7f7).interpolate(&D3Color::from_hex(0xca0020), local_t as f32)
161 }
162}
163
164fn coolwarm(t: f64) -> D3Color {
166 let colors = [
167 D3Color::from_hex(0x3b4cc0),
168 D3Color::from_hex(0x6788ee),
169 D3Color::from_hex(0x9abbff),
170 D3Color::from_hex(0xc9d7f0),
171 D3Color::from_hex(0xedd1c2),
172 D3Color::from_hex(0xf7a789),
173 D3Color::from_hex(0xe36a53),
174 D3Color::from_hex(0xb40426),
175 ];
176 interpolate_palette(t, &colors)
177}
178
179fn greys(t: f64) -> D3Color {
181 D3Color::from_hex(0xffffff).interpolate(&D3Color::from_hex(0x000000), t as f32)
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two colors agree channel-wise within a small tolerance, so endpoint checks
    /// do not depend on `D3Color::interpolate` reproducing a stop bit-for-bit.
    fn assert_close(actual: &D3Color, expected: &D3Color) {
        assert!(
            (actual.r() - expected.r()).abs() < 0.01,
            "red: {} vs {}",
            actual.r(),
            expected.r()
        );
        assert!(
            (actual.g() - expected.g()).abs() < 0.01,
            "green: {} vs {}",
            actual.g(),
            expected.g()
        );
        assert!(
            (actual.b() - expected.b()).abs() < 0.01,
            "blue: {} vs {}",
            actual.b(),
            expected.b()
        );
    }

    /// Every built-in (non-`Custom`) scale, for table-driven checks.
    fn builtin_scales() -> [ColorScale; 7] {
        [
            ColorScale::Viridis,
            ColorScale::Plasma,
            ColorScale::Inferno,
            ColorScale::Magma,
            ColorScale::Heat,
            ColorScale::Coolwarm,
            ColorScale::Greys,
        ]
    }

    #[test]
    fn test_viridis_endpoints() {
        let start = ColorScale::Viridis.map(0.0);
        let end = ColorScale::Viridis.map(1.0);
        assert!(start.r() < 0.3);
        assert!(end.g() > 0.8);
    }

    #[test]
    fn test_plasma_endpoints() {
        let start = ColorScale::Plasma.map(0.0);
        let end = ColorScale::Plasma.map(1.0);
        assert!(start.b() > 0.3);
        assert!(end.r() > 0.8);
    }

    #[test]
    fn test_heat_diverging() {
        let low = ColorScale::Heat.map(0.0);
        let mid = ColorScale::Heat.map(0.5);
        let high = ColorScale::Heat.map(1.0);
        assert!(low.b() > low.r());
        assert!(mid.r() > 0.9 && mid.g() > 0.9 && mid.b() > 0.9);
        assert!(high.r() > high.b());
    }

    #[test]
    fn test_greys_endpoints() {
        let start = ColorScale::Greys.map(0.0);
        let end = ColorScale::Greys.map(1.0);
        assert!(start.r() > 0.99 && start.g() > 0.99 && start.b() > 0.99);
        assert!(end.r() < 0.01 && end.g() < 0.01 && end.b() < 0.01);
    }

    /// Each built-in scale must start and end on its first and last palette stop.
    #[test]
    fn test_builtin_endpoints_match_stops() {
        let cases = [
            (ColorScale::Viridis, 0x440154, 0xfde725),
            (ColorScale::Plasma, 0x0d0887, 0xf0f921),
            (ColorScale::Inferno, 0x000004, 0xfcffa4),
            (ColorScale::Magma, 0x000004, 0xfcfdbf),
            (ColorScale::Heat, 0x0571b0, 0xca0020),
            (ColorScale::Coolwarm, 0x3b4cc0, 0xb40426),
            (ColorScale::Greys, 0xffffff, 0x000000),
        ];
        for (scale, first, last) in cases.iter() {
            assert_close(&scale.map(0.0), &D3Color::from_hex(*first));
            assert_close(&scale.map(1.0), &D3Color::from_hex(*last));
        }
    }

    /// Sweeping the whole input range must never yield non-finite channels.
    #[test]
    fn test_builtin_scales_are_finite_across_range() {
        for scale in builtin_scales().iter() {
            for step in 0..=20 {
                let c = scale.map(step as f64 / 20.0);
                assert!(
                    c.r().is_finite() && c.g().is_finite() && c.b().is_finite(),
                    "{:?} produced a non-finite channel at step {}",
                    scale,
                    step
                );
            }
        }
    }

    #[test]
    fn test_custom_scale() {
        let scale = ColorScale::custom(|t| {
            D3Color::from_hex(0xff0000).interpolate(&D3Color::from_hex(0x00ff00), t as f32)
        });
        let mid = scale.map(0.5);
        assert!(mid.r() > 0.4 && mid.g() > 0.4);
    }

    /// `map` clamps before delegating, so a custom function only sees `t` in `[0, 1]`.
    #[test]
    fn test_custom_receives_clamped_input() {
        let scale = ColorScale::custom(|t| {
            D3Color::from_hex(0x000000).interpolate(&D3Color::from_hex(0xffffff), t as f32)
        });
        assert_close(&scale.map(-3.0), &scale.map(0.0));
        assert_close(&scale.map(7.0), &scale.map(1.0));
    }

    #[test]
    fn test_clamp_out_of_bounds() {
        let below = ColorScale::Viridis.map(-0.5);
        let above = ColorScale::Viridis.map(1.5);
        let at_zero = ColorScale::Viridis.map(0.0);
        let at_one = ColorScale::Viridis.map(1.0);

        // Out-of-range inputs must saturate to the exact endpoint colors on every channel.
        assert_eq!(below.r(), at_zero.r());
        assert_eq!(below.g(), at_zero.g());
        assert_eq!(below.b(), at_zero.b());
        assert_eq!(above.r(), at_one.r());
        assert_eq!(above.g(), at_one.g());
        assert_eq!(above.b(), at_one.b());
    }

    #[test]
    fn test_to_fn() {
        let f = ColorScale::Viridis.to_fn();
        let direct = ColorScale::Viridis.map(0.5);
        let via_fn = f(0.5);
        assert_eq!(direct.r(), via_fn.r());
        assert_eq!(direct.g(), via_fn.g());
        assert_eq!(direct.b(), via_fn.b());

        // The returned closure is `Clone`; a clone must behave the same.
        let cloned = f.clone();
        assert_eq!(cloned(0.5).r(), direct.r());
    }

    #[test]
    fn test_default() {
        let scale = ColorScale::default();
        assert!(matches!(scale, ColorScale::Viridis));
    }

    #[test]
    fn test_debug() {
        assert_eq!(format!("{:?}", ColorScale::Viridis), "ColorScale::Viridis");
        assert_eq!(format!("{:?}", ColorScale::Plasma), "ColorScale::Plasma");
        assert_eq!(format!("{:?}", ColorScale::Inferno), "ColorScale::Inferno");
        assert_eq!(format!("{:?}", ColorScale::Magma), "ColorScale::Magma");
        assert_eq!(format!("{:?}", ColorScale::Heat), "ColorScale::Heat");
        assert_eq!(format!("{:?}", ColorScale::Coolwarm), "ColorScale::Coolwarm");
        assert_eq!(format!("{:?}", ColorScale::Greys), "ColorScale::Greys");
        let custom = ColorScale::custom(|_| D3Color::from_hex(0x000000));
        assert_eq!(format!("{:?}", custom), "ColorScale::Custom(...)");
    }
}