1use crate::csg::prelude::*;
7use burn::prelude::*;
8use serde_json::{Value, json};
9use std::error::Error;
10
/// Sampling and rendering options for [`visualize_field`].
#[derive(Debug, Clone)]
pub struct FieldVisualizationConfig {
    /// Number of grid samples along x, y and z. Grid generation only uses the first `N`
    /// entries for an `N`-dimensional field.
    pub resolution: [usize; 3],
    /// Axis-aligned sampling box as `[min_corner, max_corner]`, each `[x, y, z]`.
    pub bounds: [[f32; 3]; 2],
    /// Palette used to colour sample values.
    pub color_scheme: ColorScheme,
    /// Global opacity; configuration validation requires it to lie in `[0.0, 1.0]`.
    pub opacity: f32,
    /// Selects which renderer produces the output JSON.
    pub mode: RenderingMode,
}
25
/// Palette used to turn scalar field values into RGB colours.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColorScheme {
    /// Five-stop piecewise-linear approximation of the Viridis palette.
    Viridis,
    /// Linear ramp from blue (lowest value) to red (highest value).
    Thermal,
    /// Linear ramp from black (lowest value) to white (highest value).
    Grayscale,
    /// Diverging three-colour gradient anchored at zero.
    ///
    /// Each colour must be a `#rrggbb` hex string; colours that fail to parse are
    /// rendered as mid-grey.
    Custom {
        /// Colour at the negative end of the gradient.
        negative: String,
        /// Colour used for a value of exactly zero.
        zero: String,
        /// Colour at the positive end of the gradient.
        positive: String,
    },
}
42
/// Selects how sampled field values are turned into render data.
#[derive(Debug, Clone, PartialEq)]
pub enum RenderingMode {
    /// One RGBA texel per grid sample, for volume rendering.
    VolumeGradient,
    /// One coloured point per finite sample; `point_size` is copied into every point.
    PointCloud { point_size: f32 },
    /// Unit gradient arrows drawn toward the zero level set.
    GradientArrows {
        /// Arrow length in world units (arrow end = start + unit direction * scale).
        arrow_scale: f32,
        /// Larger values discard more arrows where `|value|` is large.
        density_factor: f32,
    },
}
56
57impl Default for FieldVisualizationConfig {
58 fn default() -> Self {
59 Self {
60 resolution: [32, 32, 32],
61 bounds: [[-2.0, -2.0, -2.0], [2.0, 2.0, 2.0]],
62 color_scheme: ColorScheme::Viridis,
63 opacity: 0.6,
64 mode: RenderingMode::VolumeGradient,
65 }
66 }
67}
68
/// Result of [`visualize_field`]: renderer output plus a summary of the sampling.
#[derive(Debug)]
pub struct FieldVisualization {
    /// Renderer-specific JSON whose `"type"` key is `"volume"`, `"points"` or `"arrows"`.
    pub render_data: Value,
    /// Dimension, sample count, value range and sampling grid used to build `render_data`.
    pub metadata: FieldMetadata,
}
77
/// Summary of how a field was sampled for visualisation.
#[derive(Debug, Clone, PartialEq)]
pub struct FieldMetadata {
    /// Spatial dimension `N` of the visualised field.
    pub dimension: usize,
    /// Number of sampled field values, including any non-finite ones.
    pub sample_count: usize,
    /// `(min, max)` range used for colour normalisation.
    pub value_range: (f32, f32),
    /// Sampling resolution copied from the configuration.
    pub resolution: [usize; 3],
    /// Sampling bounds copied from the configuration.
    pub bounds: [[f32; 3]; 2],
}
87
88pub fn visualize_field<B: Backend, const N: usize>(
90 field: &dyn ScalarField<N, B>,
91 config: &FieldVisualizationConfig,
92) -> Result<FieldVisualization, Box<dyn Error>> {
93 validate_config::<N>(config)?;
95
96 let samples = sample_field::<B, N>(field, config)?;
98
99 let render_data = match &config.mode {
101 RenderingMode::VolumeGradient => VolumeRenderer::new(config).render::<N>(&samples)?,
102 RenderingMode::PointCloud { point_size } => {
103 PointCloudRenderer::new(config, *point_size).render::<N>(&samples)?
104 }
105 RenderingMode::GradientArrows {
106 arrow_scale,
107 density_factor,
108 } => GradientArrowRenderer::new(config, *arrow_scale, *density_factor)
109 .render::<B, N>(field, &samples)?,
110 };
111
112 let metadata = FieldMetadata {
113 dimension: N,
114 sample_count: samples.values.len(),
115 value_range: (samples.min_value, samples.max_value),
116 resolution: config.resolution,
117 bounds: config.bounds,
118 };
119
120 Ok(FieldVisualization {
121 render_data,
122 metadata,
123 })
124}
125
/// Flattened grid samples shared by all renderers.
#[derive(Debug)]
struct FieldSamples {
    /// Grid coordinates read back from the `[points, N]` grid tensor: `N` consecutive
    /// floats per sample.
    positions: Vec<f32>,
    /// Field value at each grid point, in the same order as `positions`; may contain
    /// non-finite entries, which the renderers skip.
    values: Vec<f32>,
    /// Lower end of the value range used for colour normalisation.
    min_value: f32,
    /// Upper end of the value range used for colour normalisation.
    max_value: f32,
}
134
135fn sample_field<B: Backend, const N: usize>(
137 field: &dyn ScalarField<N, B>,
138 config: &FieldVisualizationConfig,
139) -> Result<FieldSamples, Box<dyn Error>> {
140 let grid_points = generate_grid::<B, N>(config)?;
141 let field_values = field.evaluate(grid_points.clone());
142
143 let positions_data = grid_points.to_data();
145 let values_data = field_values.to_data();
146
147 let positions: Vec<f32> = positions_data.iter::<f32>().collect();
148 let values: Vec<f32> = values_data.iter::<f32>().collect();
149
150 let min_value = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
152 let max_value = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
153
154 Ok(FieldSamples {
155 positions,
156 values,
157 min_value,
158 max_value,
159 })
160}
161
162fn generate_grid<B: Backend, const N: usize>(
164 config: &FieldVisualizationConfig,
165) -> Result<Tensor<B, 2>, Box<dyn Error>> {
166 let device = B::Device::default();
167 let total_points = config.resolution.iter().take(N).product::<usize>();
168 let mut points = Vec::with_capacity(total_points * N);
169
170 match N {
171 2 => generate_2d_grid(&mut points, config),
172 3 => generate_3d_grid(&mut points, config),
173 _ => return Err("Only 2D and 3D fields are supported".into()),
174 }
175
176 let tensor = Tensor::<B, 1>::from_data(points.as_slice(), &device);
177 Ok(tensor.reshape([total_points, N]))
178}
179
180fn generate_2d_grid(points: &mut Vec<f32>, config: &FieldVisualizationConfig) {
181 let [nx, ny, _] = config.resolution;
182 let [[x_min, y_min, _], [x_max, y_max, _]] = config.bounds;
183
184 for j in 0..ny {
185 for i in 0..nx {
186 let x = x_min + (i as f32 / (nx - 1) as f32) * (x_max - x_min);
187 let y = y_min + (j as f32 / (ny - 1) as f32) * (y_max - y_min);
188 points.extend_from_slice(&[x, y]);
189 }
190 }
191}
192
193fn generate_3d_grid(points: &mut Vec<f32>, config: &FieldVisualizationConfig) {
194 let [nx, ny, nz] = config.resolution;
195 let [[x_min, y_min, z_min], [x_max, y_max, z_max]] = config.bounds;
196
197 for k in 0..nz {
198 for j in 0..ny {
199 for i in 0..nx {
200 let x = x_min + (i as f32 / (nx - 1) as f32) * (x_max - x_min);
201 let y = y_min + (j as f32 / (ny - 1) as f32) * (y_max - y_min);
202 let z = z_min + (k as f32 / (nz - 1) as f32) * (z_max - z_min);
203 points.extend_from_slice(&[x, y, z]);
204 }
205 }
206 }
207}
208
209struct VolumeRenderer<'a> {
211 config: &'a FieldVisualizationConfig,
212}
213
214impl<'a> VolumeRenderer<'a> {
215 fn new(config: &'a FieldVisualizationConfig) -> Self {
216 Self { config }
217 }
218
219 fn render<const N: usize>(&self, samples: &FieldSamples) -> Result<Value, Box<dyn Error>> {
220 let color_mapper = ColorMapper::new(&self.config.color_scheme);
221 let mut texture_data = Vec::new();
222
223 for &value in samples.values.iter() {
224 if !value.is_finite() {
225 continue;
226 }
227
228 let (r, g, b) = color_mapper.map_value(value, samples.min_value, samples.max_value);
229 let alpha = self.compute_alpha(value, samples.min_value, samples.max_value);
230
231 texture_data.push(json!({
232 "r": r,
233 "g": g,
234 "b": b,
235 "a": alpha,
236 "value": value
237 }));
238 }
239
240 Ok(json!({
241 "type": "volume",
242 "mode": "gradient",
243 "resolution": self.config.resolution,
244 "bounds": self.config.bounds,
245 "textureData": texture_data,
246 "opacity": self.config.opacity,
247 "colorScheme": format!("{:?}", self.config.color_scheme)
248 }))
249 }
250
251 fn compute_alpha(&self, value: f32, min_val: f32, max_val: f32) -> u8 {
252 let range = max_val - min_val;
253 if range <= 0.0 {
254 return (self.config.opacity * 255.0) as u8;
255 }
256
257 let abs_value = value.abs();
259 let max_abs = (min_val.abs()).max(max_val.abs());
260
261 if max_abs <= 0.0 {
262 return (self.config.opacity * 255.0) as u8;
263 }
264
265 let distance_factor = 1.0 - (abs_value / max_abs).min(1.0);
268 let smooth_factor = distance_factor * distance_factor * distance_factor * distance_factor; let alpha = self.config.opacity * smooth_factor;
271 (alpha * 255.0) as u8
272 }
273}
274
275struct PointCloudRenderer<'a> {
277 config: &'a FieldVisualizationConfig,
278 point_size: f32,
279}
280
281impl<'a> PointCloudRenderer<'a> {
282 fn new(config: &'a FieldVisualizationConfig, point_size: f32) -> Self {
283 Self { config, point_size }
284 }
285
286 fn render<const N: usize>(&self, samples: &FieldSamples) -> Result<Value, Box<dyn Error>> {
287 let color_mapper = ColorMapper::new(&self.config.color_scheme);
288 let mut points = Vec::new();
289
290 for (i, &value) in samples.values.iter().enumerate() {
291 if !value.is_finite() {
292 continue;
293 }
294
295 let position = self.extract_position::<N>(&samples.positions, i);
296 let color = color_mapper.map_to_hex(value, samples.min_value, samples.max_value)?;
297
298 points.push(json!({
299 "position": position,
300 "value": value,
301 "color": color,
302 "size": self.point_size
303 }));
304 }
305
306 Ok(json!({
307 "type": "points",
308 "data": points
309 }))
310 }
311
312 fn extract_position<const N: usize>(&self, positions: &[f32], index: usize) -> [f32; 3] {
313 match N {
314 2 => [positions[index * 2], positions[index * 2 + 1], 0.0],
315 3 => [
316 positions[index * 3],
317 positions[index * 3 + 1],
318 positions[index * 3 + 2],
319 ],
320 _ => [0.0, 0.0, 0.0],
321 }
322 }
323}
324
325struct GradientArrowRenderer<'a> {
327 config: &'a FieldVisualizationConfig,
328 arrow_scale: f32,
329 density_factor: f32,
330}
331
332impl<'a> GradientArrowRenderer<'a> {
333 fn new(config: &'a FieldVisualizationConfig, arrow_scale: f32, density_factor: f32) -> Self {
334 Self {
335 config,
336 arrow_scale,
337 density_factor,
338 }
339 }
340
341 fn render<B: Backend, const N: usize>(
342 &self,
343 field: &dyn ScalarField<N, B>,
344 samples: &FieldSamples,
345 ) -> Result<Value, Box<dyn Error>> {
346 let mut arrows = Vec::new();
347 let _color_mapper = ColorMapper::new(&self.config.color_scheme);
348
349 let gradients = self.compute_gradients::<B, N>(field, samples)?;
351
352 for (i, (&value, gradient)) in samples.values.iter().zip(gradients.iter()).enumerate() {
353 if !value.is_finite() {
354 continue;
355 }
356
357 let grad_magnitude = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
359 if grad_magnitude < 1e-6 {
360 continue;
361 }
362
363 let abs_value = value.abs();
365 let density_threshold = 1.0 - (abs_value * self.density_factor).min(0.8);
367
368 let position = self.extract_position::<N>(&samples.positions, i);
370 let hash = ((position[0] * 73.0
371 + position[1] * 151.0
372 + position.get(2).unwrap_or(&0.0) * 233.0)
373 .abs()
374 * 1000.0) as u32;
375 let random_val = (hash % 1000) as f32 / 1000.0;
376
377 if random_val > density_threshold {
378 continue;
379 }
380
381 let arrow_length = self.arrow_scale;
383
384 let direction_sign = if value > 0.0 { -1.0 } else { 1.0 };
386 let direction: Vec<f32> = gradient
387 .iter()
388 .map(|x| direction_sign * x / grad_magnitude)
389 .collect();
390
391 let mut end_position = position;
393 for (i, &dir) in direction.iter().enumerate().take(N) {
394 if i < 3 {
395 end_position[i] += dir * arrow_length;
396 }
397 }
398
399 let field_proximity = 1.0 / (1.0 + abs_value * 1.5); let field_proximity = field_proximity.powf(0.5); let gray_value = (128.0 + field_proximity * 127.0) as u8; let color = format!("#{:02x}{:02x}{:02x}", gray_value, gray_value, gray_value);
408
409 arrows.push(json!({
410 "start": position,
411 "end": end_position,
412 "direction": direction,
413 "magnitude": grad_magnitude,
414 "fieldValue": value,
415 "color": color,
416 "length": arrow_length,
417 "opacity": field_proximity }));
419 }
420
421 Ok(json!({
422 "type": "arrows",
423 "data": arrows,
424 "arrowScale": self.arrow_scale,
425 "densityFactor": self.density_factor
426 }))
427 }
428
429 fn compute_gradients<B: Backend, const N: usize>(
430 &self,
431 field: &dyn ScalarField<N, B>,
432 samples: &FieldSamples,
433 ) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
434 let device = B::Device::default();
435 let epsilon = 1e-4;
436
437 let mut gradients = Vec::new();
438
439 for (i, _) in samples.values.iter().enumerate() {
441 let base_pos = self.extract_position::<N>(&samples.positions, i);
442
443 let mut gradient = Vec::new();
444
445 for dim in 0..N {
446 let mut pos_forward = base_pos;
448 let mut pos_backward = base_pos;
449
450 if dim < 3 {
451 pos_forward[dim] += epsilon;
452 pos_backward[dim] -= epsilon;
453 }
454
455 let forward_coords: Vec<f32> = [
457 pos_forward[0],
458 pos_forward[1],
459 pos_forward.get(2).copied().unwrap_or(0.0),
460 ]
461 .iter()
462 .take(N)
463 .copied()
464 .collect();
465 let backward_coords: Vec<f32> = [
466 pos_backward[0],
467 pos_backward[1],
468 pos_backward.get(2).copied().unwrap_or(0.0),
469 ]
470 .iter()
471 .take(N)
472 .copied()
473 .collect();
474
475 let tensor_forward =
476 Tensor::<B, 1>::from_data(forward_coords.as_slice(), &device).reshape([1, N]);
477 let tensor_backward =
478 Tensor::<B, 1>::from_data(backward_coords.as_slice(), &device).reshape([1, N]);
479
480 let val_forward = field
481 .evaluate(tensor_forward)
482 .to_data()
483 .iter::<f32>()
484 .next()
485 .unwrap_or(0.0);
486 let val_backward = field
487 .evaluate(tensor_backward)
488 .to_data()
489 .iter::<f32>()
490 .next()
491 .unwrap_or(0.0);
492
493 let partial_derivative = (val_forward - val_backward) / (2.0 * epsilon);
494 gradient.push(partial_derivative);
495 }
496
497 gradients.push(gradient);
498 }
499
500 Ok(gradients)
501 }
502
503 fn extract_position<const N: usize>(&self, positions: &[f32], index: usize) -> [f32; 3] {
504 match N {
505 2 => [positions[index * 2], positions[index * 2 + 1], 0.0],
506 3 => [
507 positions[index * 3],
508 positions[index * 3 + 1],
509 positions[index * 3 + 2],
510 ],
511 _ => [0.0, 0.0, 0.0],
512 }
513 }
514}
515
516struct ColorMapper {
518 scheme: ColorScheme,
519}
520
521impl ColorMapper {
522 fn new(scheme: &ColorScheme) -> Self {
523 Self {
524 scheme: scheme.clone(),
525 }
526 }
527
528 fn map_value(&self, value: f32, min_val: f32, max_val: f32) -> (u8, u8, u8) {
529 let normalized = if max_val > min_val {
530 ((value - min_val) / (max_val - min_val)).clamp(0.0, 1.0)
531 } else {
532 0.5
533 };
534
535 match &self.scheme {
536 ColorScheme::Viridis => viridis_colormap(normalized),
537 ColorScheme::Thermal => thermal_colormap(normalized),
538 ColorScheme::Grayscale => {
539 let gray = (normalized * 255.0) as u8;
540 (gray, gray, gray)
541 }
542 ColorScheme::Custom {
543 negative,
544 zero,
545 positive,
546 } => self.custom_gradient(value, min_val, max_val, negative, zero, positive),
547 }
548 }
549
550 fn map_to_hex(&self, value: f32, min_val: f32, max_val: f32) -> Result<String, Box<dyn Error>> {
551 let (r, g, b) = self.map_value(value, min_val, max_val);
552 Ok(format!("#{:02x}{:02x}{:02x}", r, g, b))
553 }
554
555 fn custom_gradient(
556 &self,
557 value: f32,
558 min_val: f32,
559 max_val: f32,
560 neg_color: &str,
561 zero_color: &str,
562 pos_color: &str,
563 ) -> (u8, u8, u8) {
564 if value < 0.0 && min_val < 0.0 {
566 let t = (value / min_val).clamp(0.0, 1.0);
567 interpolate_hex_colors(neg_color, zero_color, t).unwrap_or((128, 128, 128))
568 } else if value > 0.0 && max_val > 0.0 {
569 let t = (value / max_val).clamp(0.0, 1.0);
570 interpolate_hex_colors(zero_color, pos_color, t).unwrap_or((128, 128, 128))
571 } else {
572 parse_hex_color(zero_color).unwrap_or((128, 128, 128))
573 }
574 }
575}
576
577fn viridis_colormap(t: f32) -> (u8, u8, u8) {
579 let t = t.clamp(0.0, 1.0);
580
581 let control_points = [
583 (68, 1, 84), (59, 82, 139), (33, 145, 140), (94, 201, 98), (253, 231, 37), ];
589
590 interpolate_control_points(&control_points, t)
591}
592
/// Linear blue-to-red ramp: `t = 0` is pure blue and `t = 1` is pure red.
///
/// `t` is clamped to `[0, 1]`, and channels are truncated toward zero.
fn thermal_colormap(t: f32) -> (u8, u8, u8) {
    let heat = t.clamp(0.0, 1.0);
    let to_channel = |level: f32| (level * 255.0) as u8;
    (to_channel(heat), 0, to_channel(1.0 - heat))
}
600
/// Piecewise-linear interpolation through evenly spaced colour stops.
///
/// `t` is clamped to `[0, 1]`, so the result never extrapolates beyond the end stops.
/// NaN is treated as 0, which yields the first stop. Channels are truncated toward zero.
/// An empty `points` slice yields black.
fn interpolate_control_points(points: &[(u8, u8, u8)], t: f32) -> (u8, u8, u8) {
    if points.is_empty() {
        return (0, 0, 0);
    }

    // Without clamping, a negative `t` saturates the index to 0 but keeps a negative
    // fraction, extrapolating past the first stop.
    let t = if t.is_nan() { 0.0 } else { t.clamp(0.0, 1.0) };

    let last = points.len() - 1;
    let scaled = t * last as f32;
    let index = scaled.floor() as usize;
    if index >= last {
        return points[last];
    }

    let frac = scaled - index as f32;
    let (r1, g1, b1) = points[index];
    let (r2, g2, b2) = points[index + 1];
    let lerp = |a: u8, b: u8| (a as f32 + (b as f32 - a as f32) * frac) as u8;

    (lerp(r1, r2), lerp(g1, g2), lerp(b1, b2))
}
624
/// Parses a `#rrggbb` colour (hex digits in either case) into an RGB triple.
///
/// # Errors
/// Returns an error unless `hex` is exactly `#` followed by six ASCII hex digits.
fn parse_hex_color(hex: &str) -> Result<(u8, u8, u8), Box<dyn Error>> {
    // Validating the digits up front rejects two kinds of bad input:
    // - multi-byte characters, which would make the byte-range slicing below panic;
    // - `+` signs, which `from_str_radix` would otherwise accept.
    let digits = match hex.strip_prefix('#') {
        Some(d) if d.len() == 6 && d.bytes().all(|b| b.is_ascii_hexdigit()) => d,
        _ => return Err("Invalid hex color format".into()),
    };

    let r = u8::from_str_radix(&digits[0..2], 16)?;
    let g = u8::from_str_radix(&digits[2..4], 16)?;
    let b = u8::from_str_radix(&digits[4..6], 16)?;

    Ok((r, g, b))
}
637
638fn interpolate_hex_colors(
640 color1: &str,
641 color2: &str,
642 t: f32,
643) -> Result<(u8, u8, u8), Box<dyn Error>> {
644 let (r1, g1, b1) = parse_hex_color(color1)?;
645 let (r2, g2, b2) = parse_hex_color(color2)?;
646
647 let t = t.clamp(0.0, 1.0);
648 let r = (r1 as f32 + (r2 as f32 - r1 as f32) * t) as u8;
649 let g = (g1 as f32 + (g2 as f32 - g1 as f32) * t) as u8;
650 let b = (b1 as f32 + (b2 as f32 - b1 as f32) * t) as u8;
651
652 Ok((r, g, b))
653}
654
655fn validate_config<const N: usize>(
657 config: &FieldVisualizationConfig,
658) -> Result<(), Box<dyn Error>> {
659 if N > 3 {
660 return Err("Field visualization only supports 2D and 3D fields".into());
661 }
662
663 if config.resolution.iter().any(|&r| r < 2) {
664 return Err("Resolution must be at least 2 in each dimension".into());
665 }
666
667 if config.opacity < 0.0 || config.opacity > 1.0 {
668 return Err("Opacity must be between 0.0 and 1.0".into());
669 }
670
671 Ok(())
672}
673
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csg::fields::Field2D;
    use backend_macro::with_backend;

    // NOTE(review): `#[with_backend]` presumably injects the `Backend` type and the
    // `device()` helper used below — confirm against the `backend_macro` crate.
    #[with_backend]
    #[test]
    fn test_field_visualization() {
        let device = device();
        // Circle field, presumably of radius 1.0 — TODO confirm `Field2D::circle` semantics.
        let circle = Field2D::<Backend>::circle(1.0, device);

        // 16×16 grid over [-2, 2]²; the third resolution/bounds entries are not read when
        // generating a 2D grid.
        let config = FieldVisualizationConfig {
            resolution: [16, 16, 2],
            bounds: [[-2.0, -2.0, 0.0], [2.0, 2.0, 0.0]],
            color_scheme: ColorScheme::Viridis,
            opacity: 0.8,
            mode: RenderingMode::VolumeGradient,
        };

        let result = visualize_field(&circle, &config);
        println!("{:?}", result);

        assert!(result.is_ok());

        let visualization = result.unwrap();
        assert_eq!(visualization.metadata.dimension, 2);
        // One sample per grid point: 16 * 16.
        assert_eq!(visualization.metadata.sample_count, 256);
    }

    // The palette endpoints must reproduce the first and last control points exactly.
    #[test]
    fn test_viridis_colormap() {
        let (r, g, b) = viridis_colormap(0.0);
        assert_eq!((r, g, b), (68, 1, 84));
        let (r, g, b) = viridis_colormap(1.0);
        assert_eq!((r, g, b), (253, 231, 37));
    }

    // Opacity above 1.0 must be rejected by validation.
    #[test]
    fn test_config_validation() {
        let config = FieldVisualizationConfig {
            opacity: 1.5,
            ..Default::default()
        };

        let result = validate_config::<2>(&config);
        assert!(result.is_err());
    }
}