1use crate::color_scale::ColorScale;
4use crate::error::ChartError;
5use crate::{
6 DEFAULT_HEIGHT, DEFAULT_TITLE_FONT_SIZE, DEFAULT_WIDTH, ScaleType, TITLE_AREA_HEIGHT,
7 extent_padded, validate_data_array, validate_dimensions, validate_grid_dimensions,
8 validate_monotonic, validate_positive,
9};
10use d3rs::scale::{LinearScale, LogScale};
11use d3rs::shape::{ContourConfig, HeatmapData, render_heatmap};
12use d3rs::text::{VectorFontConfig, render_vector_text};
13use gpui::prelude::*;
14use gpui::*;
15
16#[derive(Clone)]
18pub struct HeatmapChart {
19 z: Vec<f64>,
20 grid_width: usize,
21 grid_height: usize,
22 x_values: Option<Vec<f64>>,
23 y_values: Option<Vec<f64>>,
24 x_scale_type: ScaleType,
25 y_scale_type: ScaleType,
26 color_scale: ColorScale,
27 title: Option<String>,
28 opacity: f32,
29 width: f32,
30 height: f32,
31}
32
33impl std::fmt::Debug for HeatmapChart {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("HeatmapChart")
36 .field("grid_width", &self.grid_width)
37 .field("grid_height", &self.grid_height)
38 .field("x_scale_type", &self.x_scale_type)
39 .field("y_scale_type", &self.y_scale_type)
40 .field("color_scale", &self.color_scale)
41 .field("title", &self.title)
42 .field("opacity", &self.opacity)
43 .field("width", &self.width)
44 .field("height", &self.height)
45 .finish()
46 }
47}
48
49impl HeatmapChart {
50 pub fn x(mut self, values: &[f64]) -> Self {
55 self.x_values = Some(values.to_vec());
56 self
57 }
58
59 pub fn y(mut self, values: &[f64]) -> Self {
64 self.y_values = Some(values.to_vec());
65 self
66 }
67
68 pub fn x_scale(mut self, scale: ScaleType) -> Self {
70 self.x_scale_type = scale;
71 self
72 }
73
74 pub fn y_scale(mut self, scale: ScaleType) -> Self {
76 self.y_scale_type = scale;
77 self
78 }
79
80 pub fn color_scale(mut self, scale: ColorScale) -> Self {
82 self.color_scale = scale;
83 self
84 }
85
86 pub fn title(mut self, title: impl Into<String>) -> Self {
88 self.title = Some(title.into());
89 self
90 }
91
92 pub fn opacity(mut self, opacity: f32) -> Self {
94 self.opacity = opacity.clamp(0.0, 1.0);
95 self
96 }
97
98 pub fn size(mut self, width: f32, height: f32) -> Self {
100 self.width = width;
101 self.height = height;
102 self
103 }
104
105 pub fn build(self) -> Result<impl IntoElement, ChartError> {
107 validate_data_array(&self.z, "z")?;
109 validate_grid_dimensions(&self.z, self.grid_width, self.grid_height)?;
110 validate_dimensions(self.width, self.height)?;
111
112 let x_values = match self.x_values {
114 Some(ref v) => {
115 if v.len() != self.grid_width {
116 return Err(ChartError::DataLengthMismatch {
117 x_field: "x",
118 y_field: "grid_width",
119 x_len: v.len(),
120 y_len: self.grid_width,
121 });
122 }
123 validate_data_array(v, "x")?;
124 validate_monotonic(v, "x")?;
125 if self.x_scale_type == ScaleType::Log {
126 validate_positive(v, "x")?;
127 }
128 v.clone()
129 }
130 None => (0..self.grid_width).map(|i| i as f64).collect(),
131 };
132
133 let y_values = match self.y_values {
135 Some(ref v) => {
136 if v.len() != self.grid_height {
137 return Err(ChartError::DataLengthMismatch {
138 x_field: "y",
139 y_field: "grid_height",
140 x_len: v.len(),
141 y_len: self.grid_height,
142 });
143 }
144 validate_data_array(v, "y")?;
145 validate_monotonic(v, "y")?;
146 if self.y_scale_type == ScaleType::Log {
147 validate_positive(v, "y")?;
148 }
149 v.clone()
150 }
151 None => (0..self.grid_height).map(|i| i as f64).collect(),
152 };
153
154 let title_height = if self.title.is_some() {
156 TITLE_AREA_HEIGHT
157 } else {
158 0.0
159 };
160 let plot_height = self.height - title_height;
161
162 let (x_min, x_max) = extent_padded(&x_values, 0.0);
164 let (y_min, y_max) = extent_padded(&y_values, 0.0);
165
166 let heatmap_data = HeatmapData::new(x_values, y_values, self.z.clone());
168
169 let color_fn = self.color_scale.to_fn();
171 let config = ContourConfig::new()
172 .fill(true)
173 .fill_opacity(self.opacity)
174 .color_scale(color_fn);
175
176 let heatmap_element: AnyElement = match (self.x_scale_type, self.y_scale_type) {
178 (ScaleType::Linear, ScaleType::Linear) => {
179 let x_scale = LinearScale::new()
180 .domain(x_min, x_max)
181 .range(0.0, self.width as f64);
182 let y_scale = LinearScale::new()
183 .domain(y_min, y_max)
184 .range(plot_height as f64, 0.0);
185 render_heatmap(heatmap_data, &x_scale, &y_scale, &config).into_any_element()
186 }
187 (ScaleType::Log, ScaleType::Linear) => {
188 let x_scale = LogScale::new()
189 .domain(x_min.max(1e-10), x_max)
190 .range(0.0, self.width as f64);
191 let y_scale = LinearScale::new()
192 .domain(y_min, y_max)
193 .range(plot_height as f64, 0.0);
194 render_heatmap(heatmap_data, &x_scale, &y_scale, &config).into_any_element()
195 }
196 (ScaleType::Linear, ScaleType::Log) => {
197 let x_scale = LinearScale::new()
198 .domain(x_min, x_max)
199 .range(0.0, self.width as f64);
200 let y_scale = LogScale::new()
201 .domain(y_min.max(1e-10), y_max)
202 .range(plot_height as f64, 0.0);
203 render_heatmap(heatmap_data, &x_scale, &y_scale, &config).into_any_element()
204 }
205 (ScaleType::Log, ScaleType::Log) => {
206 let x_scale = LogScale::new()
207 .domain(x_min.max(1e-10), x_max)
208 .range(0.0, self.width as f64);
209 let y_scale = LogScale::new()
210 .domain(y_min.max(1e-10), y_max)
211 .range(plot_height as f64, 0.0);
212 render_heatmap(heatmap_data, &x_scale, &y_scale, &config).into_any_element()
213 }
214 };
215
216 let mut container = div()
218 .w(px(self.width))
219 .h(px(self.height))
220 .relative()
221 .flex()
222 .flex_col();
223
224 if let Some(title) = &self.title {
226 let font_config =
227 VectorFontConfig::horizontal(DEFAULT_TITLE_FONT_SIZE, hsla(0.0, 0.0, 0.2, 1.0));
228 container = container.child(
229 div()
230 .w_full()
231 .h(px(title_height))
232 .flex()
233 .justify_center()
234 .items_center()
235 .child(render_vector_text(title, &font_config)),
236 );
237 }
238
239 container = container.child(
241 div()
242 .w(px(self.width))
243 .h(px(plot_height))
244 .relative()
245 .child(heatmap_element),
246 );
247
248 Ok(container)
249 }
250}
251
252pub fn heatmap(z: &[f64], grid_width: usize, grid_height: usize) -> HeatmapChart {
293 HeatmapChart {
294 z: z.to_vec(),
295 grid_width,
296 grid_height,
297 x_values: None,
298 y_values: None,
299 x_scale_type: ScaleType::Linear,
300 y_scale_type: ScaleType::Linear,
301 color_scale: ColorScale::default(),
302 title: None,
303 opacity: 1.0,
304 width: DEFAULT_WIDTH,
305 height: DEFAULT_HEIGHT,
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_heatmap_empty_z() {
        let outcome = heatmap(&[], 0, 0).build();
        assert!(matches!(outcome, Err(ChartError::EmptyData { field: "z" })));
    }

    #[test]
    fn test_heatmap_grid_mismatch() {
        // Five values cannot fill a 2 x 3 grid of six cells.
        let outcome = heatmap(&[1.0, 2.0, 3.0, 4.0, 5.0], 2, 3).build();
        assert!(matches!(
            outcome,
            Err(ChartError::GridDimensionMismatch {
                z_len: 5,
                width: 2,
                height: 3,
                expected: 6,
            })
        ));
    }

    #[test]
    fn test_heatmap_x_length_mismatch() {
        // Three x coordinates for a grid that is only two columns wide.
        let outcome = heatmap(&[1.0; 6], 2, 3).x(&[0.0, 1.0, 2.0]).build();
        assert!(matches!(
            outcome,
            Err(ChartError::DataLengthMismatch {
                x_field: "x",
                y_field: "grid_width",
                x_len: 3,
                y_len: 2,
            })
        ));
    }

    #[test]
    fn test_heatmap_y_length_mismatch() {
        // Two y coordinates for a grid that is three rows tall.
        let outcome = heatmap(&[1.0; 6], 2, 3).y(&[0.0, 1.0]).build();
        assert!(matches!(
            outcome,
            Err(ChartError::DataLengthMismatch {
                x_field: "y",
                y_field: "grid_height",
                x_len: 2,
                y_len: 3,
            })
        ));
    }

    #[test]
    fn test_heatmap_non_monotonic_x() {
        let outcome = heatmap(&[1.0; 4], 2, 2).x(&[1.0, 0.0]).build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "x",
                reason: "must be strictly monotonically increasing"
            })
        ));
    }

    #[test]
    fn test_heatmap_log_scale_negative() {
        let outcome = heatmap(&[1.0; 4], 2, 2)
            .x(&[-1.0, 1.0])
            .x_scale(ScaleType::Log)
            .build();
        assert!(matches!(
            outcome,
            Err(ChartError::InvalidData {
                field: "x",
                reason: "log scale requires positive values"
            })
        ));
    }

    #[test]
    fn test_heatmap_successful_build() {
        let grid: Vec<f64> = (1..=6).map(f64::from).collect();
        let outcome = heatmap(&grid, 2, 3)
            .title("Test Heatmap")
            .color_scale(ColorScale::Viridis)
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_heatmap_with_custom_axes() {
        let outcome = heatmap(&[1.0; 6], 2, 3)
            .x(&[10.0, 100.0])
            .y(&[0.0, 1.0, 2.0])
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_heatmap_log_scale() {
        let outcome = heatmap(&[1.0; 4], 2, 2)
            .x(&[10.0, 100.0])
            .y(&[1.0, 10.0])
            .x_scale(ScaleType::Log)
            .y_scale(ScaleType::Log)
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_heatmap_builder_chain() {
        let outcome = heatmap(&[1.0; 9], 3, 3)
            .title("My Heatmap")
            .color_scale(ColorScale::Plasma)
            .opacity(0.8)
            .size(800.0, 600.0)
            .build();
        assert!(outcome.is_ok());
    }
}