crater/serde/threejs/
field_visualization.rs

//! Scalar field visualization for Three.js.
//!
//! Samples a 2D or 3D scalar field on a regular grid and turns the samples into JSON
//! payloads for a Three.js client: an RGBA volume texture, a colored point cloud, or
//! gradient arrows. Volume and point-cloud output use one of several color schemes.
5
6use crate::csg::prelude::*;
7use burn::prelude::*;
8use serde_json::{Value, json};
9use std::error::Error;
10
/// Configuration for scalar field visualization.
///
/// For 2D fields the sampling grid is built from the x and y entries of
/// `resolution` and `bounds` only.
#[derive(Debug, Clone, PartialEq)]
pub struct FieldVisualizationConfig {
    /// Number of samples along each axis `[x, y, z]`.
    pub resolution: [usize; 3],
    /// Spatial bounds `[[min_x, min_y, min_z], [max_x, max_y, max_z]]`.
    /// Both ends of every axis are included in the sampling grid.
    pub bounds: [[f32; 3]; 2],
    /// Color mapping scheme.
    pub color_scheme: ColorScheme,
    /// Overall opacity for the visualization, expected in `[0.0, 1.0]`.
    pub opacity: f32,
    /// Rendering method.
    pub mode: RenderingMode,
}

/// Color schemes for field visualization.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColorScheme {
    /// Perceptually uniform viridis color map (recommended).
    Viridis,
    /// Classic blue-to-red heat map (low values blue, high values red).
    Thermal,
    /// Grayscale mapping (low values black, high values white).
    Grayscale,
    /// Custom three-point gradient anchored at zero.
    ///
    /// Colors are `#rrggbb` hex strings; colors that fail to parse are rendered
    /// as mid-gray.
    Custom {
        /// Color used toward the most negative value.
        negative: String,
        /// Color used at zero.
        zero: String,
        /// Color used toward the most positive value.
        positive: String,
    },
}

/// Field rendering methods.
#[derive(Debug, Clone, PartialEq)]
pub enum RenderingMode {
    /// Continuous volume gradient: one RGBA texel per grid sample (for smooth fields).
    VolumeGradient,
    /// Discrete point cloud with one colored point per finite sample.
    PointCloud {
        /// Point size forwarded to the client.
        point_size: f32,
    },
    /// Gradient vector arrows (direction visualization).
    GradientArrows {
        /// Length of every arrow, in world units.
        arrow_scale: f32,
        /// How aggressively arrows are thinned out as `|value|` grows.
        density_factor: f32,
    },
}

impl Default for FieldVisualizationConfig {
    /// A 32³ viridis volume gradient over `[-2, 2]³` at 60 % opacity.
    fn default() -> Self {
        Self {
            resolution: [32, 32, 32],
            bounds: [[-2.0, -2.0, -2.0], [2.0, 2.0, 2.0]],
            color_scheme: ColorScheme::Viridis,
            opacity: 0.6,
            mode: RenderingMode::VolumeGradient,
        }
    }
}
68
69/// Complete field visualization data package
70#[derive(Debug)]
71pub struct FieldVisualization {
72    /// Three.js-compatible rendering data
73    pub render_data: Value,
74    /// Field statistics and metadata
75    pub metadata: FieldMetadata,
76}
77
78/// Statistical information about the field
79#[derive(Debug)]
80pub struct FieldMetadata {
81    pub dimension: usize,
82    pub sample_count: usize,
83    pub value_range: (f32, f32),
84    pub resolution: [usize; 3],
85    pub bounds: [[f32; 3]; 2],
86}
87
88/// Main entry point for field visualization
89pub fn visualize_field<B: Backend, const N: usize>(
90    field: &dyn ScalarField<N, B>,
91    config: &FieldVisualizationConfig,
92) -> Result<FieldVisualization, Box<dyn Error>> {
93    // Validate configuration
94    validate_config::<N>(config)?;
95
96    // Sample the field on a regular grid
97    let samples = sample_field::<B, N>(field, config)?;
98
99    // Generate visualization based on rendering mode
100    let render_data = match &config.mode {
101        RenderingMode::VolumeGradient => VolumeRenderer::new(config).render::<N>(&samples)?,
102        RenderingMode::PointCloud { point_size } => {
103            PointCloudRenderer::new(config, *point_size).render::<N>(&samples)?
104        }
105        RenderingMode::GradientArrows {
106            arrow_scale,
107            density_factor,
108        } => GradientArrowRenderer::new(config, *arrow_scale, *density_factor)
109            .render::<B, N>(field, &samples)?,
110    };
111
112    let metadata = FieldMetadata {
113        dimension: N,
114        sample_count: samples.values.len(),
115        value_range: (samples.min_value, samples.max_value),
116        resolution: config.resolution,
117        bounds: config.bounds,
118    };
119
120    Ok(FieldVisualization {
121        render_data,
122        metadata,
123    })
124}
125
126/// Field sampling results
127#[derive(Debug)]
128struct FieldSamples {
129    positions: Vec<f32>,
130    values: Vec<f32>,
131    min_value: f32,
132    max_value: f32,
133}
134
135/// Sample the scalar field on a regular grid
136fn sample_field<B: Backend, const N: usize>(
137    field: &dyn ScalarField<N, B>,
138    config: &FieldVisualizationConfig,
139) -> Result<FieldSamples, Box<dyn Error>> {
140    let grid_points = generate_grid::<B, N>(config)?;
141    let field_values = field.evaluate(grid_points.clone());
142
143    // Convert to host data
144    let positions_data = grid_points.to_data();
145    let values_data = field_values.to_data();
146
147    let positions: Vec<f32> = positions_data.iter::<f32>().collect();
148    let values: Vec<f32> = values_data.iter::<f32>().collect();
149
150    // Compute value statistics
151    let min_value = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
152    let max_value = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
153
154    Ok(FieldSamples {
155        positions,
156        values,
157        min_value,
158        max_value,
159    })
160}
161
162/// Generate regular sampling grid
163fn generate_grid<B: Backend, const N: usize>(
164    config: &FieldVisualizationConfig,
165) -> Result<Tensor<B, 2>, Box<dyn Error>> {
166    let device = B::Device::default();
167    let total_points = config.resolution.iter().take(N).product::<usize>();
168    let mut points = Vec::with_capacity(total_points * N);
169
170    match N {
171        2 => generate_2d_grid(&mut points, config),
172        3 => generate_3d_grid(&mut points, config),
173        _ => return Err("Only 2D and 3D fields are supported".into()),
174    }
175
176    let tensor = Tensor::<B, 1>::from_data(points.as_slice(), &device);
177    Ok(tensor.reshape([total_points, N]))
178}
179
180fn generate_2d_grid(points: &mut Vec<f32>, config: &FieldVisualizationConfig) {
181    let [nx, ny, _] = config.resolution;
182    let [[x_min, y_min, _], [x_max, y_max, _]] = config.bounds;
183
184    for j in 0..ny {
185        for i in 0..nx {
186            let x = x_min + (i as f32 / (nx - 1) as f32) * (x_max - x_min);
187            let y = y_min + (j as f32 / (ny - 1) as f32) * (y_max - y_min);
188            points.extend_from_slice(&[x, y]);
189        }
190    }
191}
192
193fn generate_3d_grid(points: &mut Vec<f32>, config: &FieldVisualizationConfig) {
194    let [nx, ny, nz] = config.resolution;
195    let [[x_min, y_min, z_min], [x_max, y_max, z_max]] = config.bounds;
196
197    for k in 0..nz {
198        for j in 0..ny {
199            for i in 0..nx {
200                let x = x_min + (i as f32 / (nx - 1) as f32) * (x_max - x_min);
201                let y = y_min + (j as f32 / (ny - 1) as f32) * (y_max - y_min);
202                let z = z_min + (k as f32 / (nz - 1) as f32) * (z_max - z_min);
203                points.extend_from_slice(&[x, y, z]);
204            }
205        }
206    }
207}
208
209/// Volume gradient renderer for continuous field visualization
210struct VolumeRenderer<'a> {
211    config: &'a FieldVisualizationConfig,
212}
213
214impl<'a> VolumeRenderer<'a> {
215    fn new(config: &'a FieldVisualizationConfig) -> Self {
216        Self { config }
217    }
218
219    fn render<const N: usize>(&self, samples: &FieldSamples) -> Result<Value, Box<dyn Error>> {
220        let color_mapper = ColorMapper::new(&self.config.color_scheme);
221        let mut texture_data = Vec::new();
222
223        for &value in samples.values.iter() {
224            if !value.is_finite() {
225                continue;
226            }
227
228            let (r, g, b) = color_mapper.map_value(value, samples.min_value, samples.max_value);
229            let alpha = self.compute_alpha(value, samples.min_value, samples.max_value);
230
231            texture_data.push(json!({
232                "r": r,
233                "g": g,
234                "b": b,
235                "a": alpha,
236                "value": value
237            }));
238        }
239
240        Ok(json!({
241            "type": "volume",
242            "mode": "gradient",
243            "resolution": self.config.resolution,
244            "bounds": self.config.bounds,
245            "textureData": texture_data,
246            "opacity": self.config.opacity,
247            "colorScheme": format!("{:?}", self.config.color_scheme)
248        }))
249    }
250
251    fn compute_alpha(&self, value: f32, min_val: f32, max_val: f32) -> u8 {
252        let range = max_val - min_val;
253        if range <= 0.0 {
254            return (self.config.opacity * 255.0) as u8;
255        }
256
257        // Simple, smooth gradient based on distance from zero isosurface
258        let abs_value = value.abs();
259        let max_abs = (min_val.abs()).max(max_val.abs());
260
261        if max_abs <= 0.0 {
262            return (self.config.opacity * 255.0) as u8;
263        }
264
265        // Fade out as we get further from the isosurface (value = 0)
266        // This creates a subtle visualization focused on the boundary
267        let distance_factor = 1.0 - (abs_value / max_abs).min(1.0);
268        let smooth_factor = distance_factor * distance_factor * distance_factor * distance_factor; // Quartic falloff for smoother appearance
269
270        let alpha = self.config.opacity * smooth_factor;
271        (alpha * 255.0) as u8
272    }
273}
274
275/// Point cloud renderer for discrete visualization
276struct PointCloudRenderer<'a> {
277    config: &'a FieldVisualizationConfig,
278    point_size: f32,
279}
280
281impl<'a> PointCloudRenderer<'a> {
282    fn new(config: &'a FieldVisualizationConfig, point_size: f32) -> Self {
283        Self { config, point_size }
284    }
285
286    fn render<const N: usize>(&self, samples: &FieldSamples) -> Result<Value, Box<dyn Error>> {
287        let color_mapper = ColorMapper::new(&self.config.color_scheme);
288        let mut points = Vec::new();
289
290        for (i, &value) in samples.values.iter().enumerate() {
291            if !value.is_finite() {
292                continue;
293            }
294
295            let position = self.extract_position::<N>(&samples.positions, i);
296            let color = color_mapper.map_to_hex(value, samples.min_value, samples.max_value)?;
297
298            points.push(json!({
299                "position": position,
300                "value": value,
301                "color": color,
302                "size": self.point_size
303            }));
304        }
305
306        Ok(json!({
307            "type": "points",
308            "data": points
309        }))
310    }
311
312    fn extract_position<const N: usize>(&self, positions: &[f32], index: usize) -> [f32; 3] {
313        match N {
314            2 => [positions[index * 2], positions[index * 2 + 1], 0.0],
315            3 => [
316                positions[index * 3],
317                positions[index * 3 + 1],
318                positions[index * 3 + 2],
319            ],
320            _ => [0.0, 0.0, 0.0],
321        }
322    }
323}
324
325/// Gradient arrow renderer for vector field visualization
326struct GradientArrowRenderer<'a> {
327    config: &'a FieldVisualizationConfig,
328    arrow_scale: f32,
329    density_factor: f32,
330}
331
332impl<'a> GradientArrowRenderer<'a> {
333    fn new(config: &'a FieldVisualizationConfig, arrow_scale: f32, density_factor: f32) -> Self {
334        Self {
335            config,
336            arrow_scale,
337            density_factor,
338        }
339    }
340
341    fn render<B: Backend, const N: usize>(
342        &self,
343        field: &dyn ScalarField<N, B>,
344        samples: &FieldSamples,
345    ) -> Result<Value, Box<dyn Error>> {
346        let mut arrows = Vec::new();
347        let _color_mapper = ColorMapper::new(&self.config.color_scheme);
348
349        // Compute gradients numerically
350        let gradients = self.compute_gradients::<B, N>(field, samples)?;
351
352        for (i, (&value, gradient)) in samples.values.iter().zip(gradients.iter()).enumerate() {
353            if !value.is_finite() {
354                continue;
355            }
356
357            // Skip points with very small gradient magnitude (no interesting direction)
358            let grad_magnitude = gradient.iter().map(|x| x * x).sum::<f32>().sqrt();
359            if grad_magnitude < 1e-6 {
360                continue;
361            }
362
363            // Density filtering: favor points near isosurface (low absolute field values)
364            let abs_value = value.abs();
365            // Invert the density threshold - higher density near zero field values
366            let density_threshold = 1.0 - (abs_value * self.density_factor).min(0.8);
367
368            // Use deterministic sampling based on position for consistent results
369            let position = self.extract_position::<N>(&samples.positions, i);
370            let hash = ((position[0] * 73.0
371                + position[1] * 151.0
372                + position.get(2).unwrap_or(&0.0) * 233.0)
373                .abs()
374                * 1000.0) as u32;
375            let random_val = (hash % 1000) as f32 / 1000.0;
376
377            if random_val > density_threshold {
378                continue;
379            }
380
381            // Fixed arrow length - all arrows same size, only direction varies
382            let arrow_length = self.arrow_scale;
383
384            // Normalize gradient direction and flip if field is positive (point toward zero)
385            let direction_sign = if value > 0.0 { -1.0 } else { 1.0 };
386            let direction: Vec<f32> = gradient
387                .iter()
388                .map(|x| direction_sign * x / grad_magnitude)
389                .collect();
390
391            // Arrow endpoint
392            let mut end_position = position;
393            for (i, &dir) in direction.iter().enumerate().take(N) {
394                if i < 3 {
395                    end_position[i] += dir * arrow_length;
396                }
397            }
398
399            // Color based on field value proximity to zero (brighter near isosurface)
400            // Use inverse of absolute field value so arrows are brightest near zero
401            // More aggressive scaling to make more arrows bright
402            let field_proximity = 1.0 / (1.0 + abs_value * 1.5); // Higher values = closer to zero
403            let field_proximity = field_proximity.powf(0.5); // Square root for more aggressive scaling
404
405            // Simple gray to white interpolation based on proximity
406            let gray_value = (128.0 + field_proximity * 127.0) as u8; // 128-255 range
407            let color = format!("#{:02x}{:02x}{:02x}", gray_value, gray_value, gray_value);
408
409            arrows.push(json!({
410                "start": position,
411                "end": end_position,
412                "direction": direction,
413                "magnitude": grad_magnitude,
414                "fieldValue": value,
415                "color": color,
416                "length": arrow_length,
417                "opacity": field_proximity  // Pass proximity for opacity control
418            }));
419        }
420
421        Ok(json!({
422            "type": "arrows",
423            "data": arrows,
424            "arrowScale": self.arrow_scale,
425            "densityFactor": self.density_factor
426        }))
427    }
428
429    fn compute_gradients<B: Backend, const N: usize>(
430        &self,
431        field: &dyn ScalarField<N, B>,
432        samples: &FieldSamples,
433    ) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
434        let device = B::Device::default();
435        let epsilon = 1e-4;
436
437        let mut gradients = Vec::new();
438
439        // Compute numerical gradients using finite differences
440        for (i, _) in samples.values.iter().enumerate() {
441            let base_pos = self.extract_position::<N>(&samples.positions, i);
442
443            let mut gradient = Vec::new();
444
445            for dim in 0..N {
446                // Forward difference
447                let mut pos_forward = base_pos;
448                let mut pos_backward = base_pos;
449
450                if dim < 3 {
451                    pos_forward[dim] += epsilon;
452                    pos_backward[dim] -= epsilon;
453                }
454
455                // Evaluate field at forward and backward positions
456                let forward_coords: Vec<f32> = [
457                    pos_forward[0],
458                    pos_forward[1],
459                    pos_forward.get(2).copied().unwrap_or(0.0),
460                ]
461                .iter()
462                .take(N)
463                .copied()
464                .collect();
465                let backward_coords: Vec<f32> = [
466                    pos_backward[0],
467                    pos_backward[1],
468                    pos_backward.get(2).copied().unwrap_or(0.0),
469                ]
470                .iter()
471                .take(N)
472                .copied()
473                .collect();
474
475                let tensor_forward =
476                    Tensor::<B, 1>::from_data(forward_coords.as_slice(), &device).reshape([1, N]);
477                let tensor_backward =
478                    Tensor::<B, 1>::from_data(backward_coords.as_slice(), &device).reshape([1, N]);
479
480                let val_forward = field
481                    .evaluate(tensor_forward)
482                    .to_data()
483                    .iter::<f32>()
484                    .next()
485                    .unwrap_or(0.0);
486                let val_backward = field
487                    .evaluate(tensor_backward)
488                    .to_data()
489                    .iter::<f32>()
490                    .next()
491                    .unwrap_or(0.0);
492
493                let partial_derivative = (val_forward - val_backward) / (2.0 * epsilon);
494                gradient.push(partial_derivative);
495            }
496
497            gradients.push(gradient);
498        }
499
500        Ok(gradients)
501    }
502
503    fn extract_position<const N: usize>(&self, positions: &[f32], index: usize) -> [f32; 3] {
504        match N {
505            2 => [positions[index * 2], positions[index * 2 + 1], 0.0],
506            3 => [
507                positions[index * 3],
508                positions[index * 3 + 1],
509                positions[index * 3 + 2],
510            ],
511            _ => [0.0, 0.0, 0.0],
512        }
513    }
514}
515
516/// Professional color mapping with multiple schemes
517struct ColorMapper {
518    scheme: ColorScheme,
519}
520
521impl ColorMapper {
522    fn new(scheme: &ColorScheme) -> Self {
523        Self {
524            scheme: scheme.clone(),
525        }
526    }
527
528    fn map_value(&self, value: f32, min_val: f32, max_val: f32) -> (u8, u8, u8) {
529        let normalized = if max_val > min_val {
530            ((value - min_val) / (max_val - min_val)).clamp(0.0, 1.0)
531        } else {
532            0.5
533        };
534
535        match &self.scheme {
536            ColorScheme::Viridis => viridis_colormap(normalized),
537            ColorScheme::Thermal => thermal_colormap(normalized),
538            ColorScheme::Grayscale => {
539                let gray = (normalized * 255.0) as u8;
540                (gray, gray, gray)
541            }
542            ColorScheme::Custom {
543                negative,
544                zero,
545                positive,
546            } => self.custom_gradient(value, min_val, max_val, negative, zero, positive),
547        }
548    }
549
550    fn map_to_hex(&self, value: f32, min_val: f32, max_val: f32) -> Result<String, Box<dyn Error>> {
551        let (r, g, b) = self.map_value(value, min_val, max_val);
552        Ok(format!("#{:02x}{:02x}{:02x}", r, g, b))
553    }
554
555    fn custom_gradient(
556        &self,
557        value: f32,
558        min_val: f32,
559        max_val: f32,
560        neg_color: &str,
561        zero_color: &str,
562        pos_color: &str,
563    ) -> (u8, u8, u8) {
564        // Simplified three-color interpolation
565        if value < 0.0 && min_val < 0.0 {
566            let t = (value / min_val).clamp(0.0, 1.0);
567            interpolate_hex_colors(neg_color, zero_color, t).unwrap_or((128, 128, 128))
568        } else if value > 0.0 && max_val > 0.0 {
569            let t = (value / max_val).clamp(0.0, 1.0);
570            interpolate_hex_colors(zero_color, pos_color, t).unwrap_or((128, 128, 128))
571        } else {
572            parse_hex_color(zero_color).unwrap_or((128, 128, 128))
573        }
574    }
575}
576
/// Viridis color map, built by linear interpolation between five control
/// points of the palette. `t` is clamped to `[0, 1]`.
fn viridis_colormap(t: f32) -> (u8, u8, u8) {
    const STOPS: [(u8, u8, u8); 5] = [
        (68, 1, 84),    // deep purple
        (59, 82, 139),  // blue
        (33, 145, 140), // teal
        (94, 201, 98),  // green
        (253, 231, 37), // yellow
    ];

    interpolate_control_points(&STOPS, t.clamp(0.0, 1.0))
}

/// Blue-to-red color map: `t = 0` is pure blue, `t = 1` pure red, with green
/// always zero. `t` is clamped to `[0, 1]`.
fn thermal_colormap(t: f32) -> (u8, u8, u8) {
    let weight = t.clamp(0.0, 1.0);
    let channel = |w: f32| (w * 255.0) as u8;
    (channel(weight), 0, channel(1.0 - weight))
}

/// Piecewise-linear interpolation across evenly spaced color stops.
///
/// Returns black for an empty slice and the last stop once `t` reaches 1.
/// Channels are truncated (not rounded) when converted back to `u8`.
fn interpolate_control_points(points: &[(u8, u8, u8)], t: f32) -> (u8, u8, u8) {
    let last = match points.last() {
        Some(&stop) => stop,
        None => return (0, 0, 0),
    };

    let segments = points.len() - 1;
    let scaled = t * segments as f32;
    let lower = scaled.floor() as usize;
    if lower >= segments {
        return last;
    }

    let frac = scaled - lower as f32;
    let (from, to) = (points[lower], points[lower + 1]);
    let mix = |a: u8, b: u8| (a as f32 + (b as f32 - a as f32) * frac) as u8;

    (mix(from.0, to.0), mix(from.1, to.1), mix(from.2, to.2))
}
624
/// Parse a CSS-style hex color into RGB.
///
/// Accepts `#rrggbb` and the shorthand `#rgb`, where each digit is doubled
/// (`#f80` equals `#ff8800`). Digits are case-insensitive.
///
/// # Errors
/// Returns an error if any of these hold:
/// - the leading `#` is missing;
/// - the string does not have 3 or 6 digits after the `#`;
/// - any digit is not an ASCII hex digit.
///
/// The input is validated before slicing, so multi-byte UTF-8 characters give an
/// error instead of a slicing panic, and `from_str_radix`'s acceptance of a
/// leading `+` sign cannot leak through.
fn parse_hex_color(hex: &str) -> Result<(u8, u8, u8), Box<dyn Error>> {
    let digits = hex
        .strip_prefix('#')
        .filter(|d| (d.len() == 6 || d.len() == 3) && d.bytes().all(|b| b.is_ascii_hexdigit()))
        .ok_or("Invalid hex color format")?;

    // Every byte is an ASCII hex digit, so every byte index is a char boundary.
    let channel = |s: &str| u8::from_str_radix(s, 16);

    let rgb = if digits.len() == 6 {
        (
            channel(&digits[0..2])?,
            channel(&digits[2..4])?,
            channel(&digits[4..6])?,
        )
    } else {
        // Shorthand: 0xN * 17 == 0xNN, at most 0xF * 17 = 255.
        let short = |i: usize| channel(&digits[i..=i]).map(|v| v * 17);
        (short(0)?, short(1)?, short(2)?)
    };

    Ok(rgb)
}

/// Linearly interpolate between two hex colors.
///
/// `t` is clamped to `[0, 1]`: `0` yields `color1`, `1` yields `color2`.
/// Channels are truncated when converted back to `u8`.
///
/// # Errors
/// Propagates [`parse_hex_color`] errors for either input.
fn interpolate_hex_colors(
    color1: &str,
    color2: &str,
    t: f32,
) -> Result<(u8, u8, u8), Box<dyn Error>> {
    let (r1, g1, b1) = parse_hex_color(color1)?;
    let (r2, g2, b2) = parse_hex_color(color2)?;

    let t = t.clamp(0.0, 1.0);
    let lerp = |a: u8, b: u8| (a as f32 + (b as f32 - a as f32) * t) as u8;

    Ok((lerp(r1, r2), lerp(g1, g2), lerp(b1, b2)))
}
654
655/// Validate configuration parameters
656fn validate_config<const N: usize>(
657    config: &FieldVisualizationConfig,
658) -> Result<(), Box<dyn Error>> {
659    if N > 3 {
660        return Err("Field visualization only supports 2D and 3D fields".into());
661    }
662
663    if config.resolution.iter().any(|&r| r < 2) {
664        return Err("Resolution must be at least 2 in each dimension".into());
665    }
666
667    if config.opacity < 0.0 || config.opacity > 1.0 {
668        return Err("Opacity must be between 0.0 and 1.0".into());
669    }
670
671    Ok(())
672}
673
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csg::fields::Field2D;
    use backend_macro::with_backend;

    #[with_backend]
    #[test]
    fn test_field_visualization() {
        let device = device();
        let circle = Field2D::<Backend>::circle(1.0, device);

        // The z resolution/bounds are not used when sampling a 2D field.
        let config = FieldVisualizationConfig {
            resolution: [16, 16, 2],
            bounds: [[-2.0, -2.0, 0.0], [2.0, 2.0, 0.0]],
            color_scheme: ColorScheme::Viridis,
            opacity: 0.8,
            mode: RenderingMode::VolumeGradient,
        };

        // `expect` surfaces the underlying error message if visualization fails.
        let visualization = visualize_field(&circle, &config)
            .expect("visualizing a 2D circle field should succeed");

        assert_eq!(visualization.metadata.dimension, 2);
        assert_eq!(visualization.metadata.sample_count, 256); // 16x16
    }

    #[test]
    fn test_viridis_colormap() {
        let (r, g, b) = viridis_colormap(0.0);
        assert_eq!((r, g, b), (68, 1, 84)); // Should be deep purple

        let (r, g, b) = viridis_colormap(1.0);
        assert_eq!((r, g, b), (253, 231, 37)); // Should be yellow
    }

    #[test]
    fn test_thermal_colormap_endpoints() {
        assert_eq!(thermal_colormap(0.0), (0, 0, 255)); // pure blue
        assert_eq!(thermal_colormap(1.0), (255, 0, 0)); // pure red
    }

    #[test]
    fn test_parse_hex_color() {
        assert_eq!(parse_hex_color("#ff8000").unwrap(), (255, 128, 0));
        assert!(parse_hex_color("ff8000").is_err());
    }

    #[test]
    fn test_config_validation() {
        let config = FieldVisualizationConfig {
            opacity: 1.5,
            ..Default::default()
        };

        let result = validate_config::<2>(&config);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_validation_rejects_single_sample_axis() {
        let config = FieldVisualizationConfig {
            resolution: [16, 16, 1],
            ..Default::default()
        };

        // A sampled axis with one sample would divide by zero in the grid spacing.
        assert!(validate_config::<3>(&config).is_err());
    }
}