// runmat_plot/plots/surface.rs

//! 3D surface plot implementation
//!
//! High-performance GPU-accelerated 3D surface rendering.

5use crate::core::{BoundingBox, DrawCall, Material, PipelineType, RenderData, Vertex};
6use glam::{Vec3, Vec4};
7
8/// High-performance GPU-accelerated 3D surface plot
9#[derive(Debug, Clone)]
10pub struct SurfacePlot {
11    /// Grid data (Z values at X,Y coordinates)
12    pub x_data: Vec<f64>,
13    pub y_data: Vec<f64>,
14    pub z_data: Vec<Vec<f64>>, // z_data[i][j] = Z value at (x_data[i], y_data[j])
15
16    /// Surface properties
17    pub colormap: ColorMap,
18    pub shading_mode: ShadingMode,
19    pub wireframe: bool,
20    pub alpha: f32,
21
22    /// Lighting and material
23    pub lighting_enabled: bool,
24    pub ambient_strength: f32,
25    pub diffuse_strength: f32,
26    pub specular_strength: f32,
27    pub shininess: f32,
28
29    /// Metadata
30    pub label: Option<String>,
31    pub visible: bool,
32
33    /// Generated rendering data (cached)
34    vertices: Option<Vec<Vertex>>,
35    indices: Option<Vec<u32>>,
36    bounds: Option<BoundingBox>,
37    dirty: bool,
38}
39
40/// Color mapping schemes
41#[derive(Debug, Clone, Copy, PartialEq)]
42pub enum ColorMap {
43    /// MATLAB-compatible colormaps
44    Jet,
45    Hot,
46    Cool,
47    Spring,
48    Summer,
49    Autumn,
50    Winter,
51    Gray,
52    Bone,
53    Copper,
54    Pink,
55    Lines,
56
57    /// Scientific colormaps
58    Viridis,
59    Plasma,
60    Inferno,
61    Magma,
62    Turbo,
63
64    /// Perceptually uniform
65    Parula,
66
67    /// Custom color ranges
68    Custom(Vec4, Vec4), // (min_color, max_color)
69}
70
/// How a surface is shaded when rendered.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ShadingMode {
    /// One normal per face, so each face is lit uniformly.
    Flat,
    /// Normals interpolated across faces for a smooth look.
    Smooth,
    /// Flat shading with the mesh edges visible.
    Faceted,
    /// Color mapping only, no shading.
    None,
}

84impl Default for ColorMap {
85    fn default() -> Self {
86        Self::Viridis
87    }
88}
89
90impl Default for ShadingMode {
91    fn default() -> Self {
92        Self::Smooth
93    }
94}
95
96impl SurfacePlot {
97    /// Create a new surface plot from meshgrid data
98    pub fn new(x_data: Vec<f64>, y_data: Vec<f64>, z_data: Vec<Vec<f64>>) -> Result<Self, String> {
99        // Validate dimensions
100        if z_data.len() != x_data.len() {
101            return Err(format!(
102                "Z data rows ({}) must match X data length ({})",
103                z_data.len(),
104                x_data.len()
105            ));
106        }
107
108        for (i, row) in z_data.iter().enumerate() {
109            if row.len() != y_data.len() {
110                return Err(format!(
111                    "Z data row {} length ({}) must match Y data length ({})",
112                    i,
113                    row.len(),
114                    y_data.len()
115                ));
116            }
117        }
118
119        Ok(Self {
120            x_data,
121            y_data,
122            z_data,
123            colormap: ColorMap::default(),
124            shading_mode: ShadingMode::default(),
125            wireframe: false,
126            alpha: 1.0,
127            lighting_enabled: true,
128            ambient_strength: 0.2,
129            diffuse_strength: 0.8,
130            specular_strength: 0.5,
131            shininess: 32.0,
132            label: None,
133            visible: true,
134            vertices: None,
135            indices: None,
136            bounds: None,
137            dirty: true,
138        })
139    }
140
141    /// Create surface from a function
142    pub fn from_function<F>(
143        x_range: (f64, f64),
144        y_range: (f64, f64),
145        resolution: (usize, usize),
146        func: F,
147    ) -> Result<Self, String>
148    where
149        F: Fn(f64, f64) -> f64,
150    {
151        let (x_res, y_res) = resolution;
152        if x_res < 2 || y_res < 2 {
153            return Err("Resolution must be at least 2x2".to_string());
154        }
155
156        let x_data: Vec<f64> = (0..x_res)
157            .map(|i| x_range.0 + (x_range.1 - x_range.0) * i as f64 / (x_res - 1) as f64)
158            .collect();
159
160        let y_data: Vec<f64> = (0..y_res)
161            .map(|j| y_range.0 + (y_range.1 - y_range.0) * j as f64 / (y_res - 1) as f64)
162            .collect();
163
164        let z_data: Vec<Vec<f64>> = x_data
165            .iter()
166            .map(|&x| y_data.iter().map(|&y| func(x, y)).collect())
167            .collect();
168
169        Self::new(x_data, y_data, z_data)
170    }
171
172    /// Set color mapping
173    pub fn with_colormap(mut self, colormap: ColorMap) -> Self {
174        self.colormap = colormap;
175        self.dirty = true;
176        self
177    }
178
179    /// Set shading mode
180    pub fn with_shading(mut self, shading: ShadingMode) -> Self {
181        self.shading_mode = shading;
182        self.dirty = true;
183        self
184    }
185
186    /// Enable/disable wireframe
187    pub fn with_wireframe(mut self, enabled: bool) -> Self {
188        self.wireframe = enabled;
189        self.dirty = true;
190        self
191    }
192
193    /// Set transparency
194    pub fn with_alpha(mut self, alpha: f32) -> Self {
195        self.alpha = alpha.clamp(0.0, 1.0);
196        self.dirty = true;
197        self
198    }
199
200    /// Set plot label for legends
201    pub fn with_label<S: Into<String>>(mut self, label: S) -> Self {
202        self.label = Some(label.into());
203        self
204    }
205
206    /// Get the number of grid points
207    pub fn len(&self) -> usize {
208        self.x_data.len() * self.y_data.len()
209    }
210
211    /// Check if the surface has no data
212    pub fn is_empty(&self) -> bool {
213        self.x_data.is_empty() || self.y_data.is_empty()
214    }
215
216    /// Get the bounding box of the surface
217    pub fn bounds(&mut self) -> BoundingBox {
218        if self.dirty || self.bounds.is_none() {
219            self.compute_bounds();
220        }
221        self.bounds.unwrap()
222    }
223
224    /// Compute bounding box
225    fn compute_bounds(&mut self) {
226        let mut min_x = f32::INFINITY;
227        let mut max_x = f32::NEG_INFINITY;
228        let mut min_y = f32::INFINITY;
229        let mut max_y = f32::NEG_INFINITY;
230        let mut min_z = f32::INFINITY;
231        let mut max_z = f32::NEG_INFINITY;
232
233        for &x in &self.x_data {
234            min_x = min_x.min(x as f32);
235            max_x = max_x.max(x as f32);
236        }
237
238        for &y in &self.y_data {
239            min_y = min_y.min(y as f32);
240            max_y = max_y.max(y as f32);
241        }
242
243        for row in &self.z_data {
244            for &z in row {
245                if z.is_finite() {
246                    min_z = min_z.min(z as f32);
247                    max_z = max_z.max(z as f32);
248                }
249            }
250        }
251
252        self.bounds = Some(BoundingBox::new(
253            Vec3::new(min_x, min_y, min_z),
254            Vec3::new(max_x, max_y, max_z),
255        ));
256    }
257
258    /// Get plot statistics for debugging
259    pub fn statistics(&self) -> SurfaceStatistics {
260        let grid_size = self.x_data.len() * self.y_data.len();
261        let triangle_count = if self.x_data.len() > 1 && self.y_data.len() > 1 {
262            (self.x_data.len() - 1) * (self.y_data.len() - 1) * 2
263        } else {
264            0
265        };
266
267        SurfaceStatistics {
268            grid_points: grid_size,
269            triangle_count,
270            x_resolution: self.x_data.len(),
271            y_resolution: self.y_data.len(),
272            memory_usage: self.estimated_memory_usage(),
273        }
274    }
275
276    /// Estimate memory usage in bytes
277    pub fn estimated_memory_usage(&self) -> usize {
278        let data_size = std::mem::size_of::<f64>()
279            * (self.x_data.len() + self.y_data.len() + self.z_data.len() * self.y_data.len());
280
281        let vertices_size = self
282            .vertices
283            .as_ref()
284            .map_or(0, |v| v.len() * std::mem::size_of::<Vertex>());
285
286        let indices_size = self
287            .indices
288            .as_ref()
289            .map_or(0, |i| i.len() * std::mem::size_of::<u32>());
290
291        data_size + vertices_size + indices_size
292    }
293
294    /// Generate vertices for surface mesh
295    fn generate_vertices(&mut self) -> &Vec<Vertex> {
296        if self.dirty || self.vertices.is_none() {
297            println!(
298                "DEBUG: Generating surface vertices for {} x {} grid",
299                self.x_data.len(),
300                self.y_data.len()
301            );
302
303            let mut vertices = Vec::new();
304
305            // Find Z value range for color mapping
306            let mut min_z = f64::INFINITY;
307            let mut max_z = f64::NEG_INFINITY;
308            for row in &self.z_data {
309                for &z in row {
310                    if z.is_finite() {
311                        min_z = min_z.min(z);
312                        max_z = max_z.max(z);
313                    }
314                }
315            }
316            let z_range = max_z - min_z;
317
318            // Generate vertices for each grid point
319            for (i, &x) in self.x_data.iter().enumerate() {
320                for (j, &y) in self.y_data.iter().enumerate() {
321                    let z = self.z_data[i][j];
322                    let position = Vec3::new(x as f32, y as f32, z as f32);
323
324                    // Simple normal calculation (can be improved with proper gradients)
325                    let normal = Vec3::new(0.0, 0.0, 1.0); // Placeholder
326
327                    // Color based on Z value using colormap
328                    let t = if z_range > 0.0 {
329                        ((z - min_z) / z_range) as f32
330                    } else {
331                        0.5
332                    };
333                    let color_rgb = self.colormap.map_value(t);
334                    let color = Vec4::new(color_rgb.x, color_rgb.y, color_rgb.z, self.alpha);
335
336                    vertices.push(Vertex {
337                        position: position.to_array(),
338                        normal: normal.to_array(),
339                        color: color.to_array(),
340                        tex_coords: [
341                            i as f32 / (self.x_data.len() - 1).max(1) as f32,
342                            j as f32 / (self.y_data.len() - 1).max(1) as f32,
343                        ],
344                    });
345                }
346            }
347
348            println!("DEBUG: Generated {} vertices for surface", vertices.len());
349            self.vertices = Some(vertices);
350        }
351        self.vertices.as_ref().unwrap()
352    }
353
354    /// Generate indices for surface triangulation
355    fn generate_indices(&mut self) -> &Vec<u32> {
356        if self.dirty || self.indices.is_none() {
357            println!("DEBUG: Generating surface indices");
358
359            let mut indices = Vec::new();
360            let x_res = self.x_data.len();
361            let y_res = self.y_data.len();
362
363            // Generate triangle indices for surface mesh
364            for i in 0..x_res - 1 {
365                for j in 0..y_res - 1 {
366                    let base = (i * y_res + j) as u32;
367                    let next_row = base + y_res as u32;
368
369                    // Two triangles per quad
370                    // Triangle 1: (i,j), (i+1,j), (i,j+1)
371                    indices.push(base);
372                    indices.push(next_row);
373                    indices.push(base + 1);
374
375                    // Triangle 2: (i+1,j), (i+1,j+1), (i,j+1)
376                    indices.push(next_row);
377                    indices.push(next_row + 1);
378                    indices.push(base + 1);
379                }
380            }
381
382            println!("DEBUG: Generated {} indices for surface", indices.len());
383            self.indices = Some(indices);
384            self.dirty = false;
385        }
386        self.indices.as_ref().unwrap()
387    }
388
389    /// Generate complete render data for the graphics pipeline
390    pub fn render_data(&mut self) -> RenderData {
391        println!(
392            "DEBUG: SurfacePlot::render_data() called for {} x {} surface",
393            self.x_data.len(),
394            self.y_data.len()
395        );
396
397        let vertices = self.generate_vertices().clone();
398        let indices = self.generate_indices().clone();
399
400        println!(
401            "DEBUG: Surface render data: {} vertices, {} indices",
402            vertices.len(),
403            indices.len()
404        );
405
406        let material = Material {
407            albedo: Vec4::new(1.0, 1.0, 1.0, self.alpha),
408            ..Default::default()
409        };
410
411        let draw_call = DrawCall {
412            vertex_offset: 0,
413            vertex_count: vertices.len(),
414            index_offset: Some(0),
415            index_count: Some(indices.len()),
416            instance_count: 1,
417        };
418
419        println!("DEBUG: SurfacePlot render_data completed successfully");
420
421        RenderData {
422            pipeline_type: if self.wireframe {
423                PipelineType::Lines
424            } else {
425                PipelineType::Triangles
426            },
427            vertices,
428            indices: Some(indices),
429            material,
430            draw_calls: vec![draw_call],
431        }
432    }
433}
434
/// Size and memory figures for a surface plot, as reported by `SurfacePlot::statistics`.
#[derive(Debug, Clone)]
pub struct SurfaceStatistics {
    /// Total number of grid samples (`x_resolution * y_resolution`).
    pub grid_points: usize,
    /// Triangles needed to fill the grid (two per cell).
    pub triangle_count: usize,
    /// Number of samples along X.
    pub x_resolution: usize,
    /// Number of samples along Y.
    pub y_resolution: usize,
    /// Estimated bytes held by the grid data and any cached geometry.
    pub memory_usage: usize,
}

445impl ColorMap {
446    /// Map a normalized value [0,1] to a color
447    pub fn map_value(&self, t: f32) -> Vec3 {
448        let t = t.clamp(0.0, 1.0);
449
450        match self {
451            ColorMap::Jet => self.jet_colormap(t),
452            ColorMap::Hot => self.hot_colormap(t),
453            ColorMap::Cool => self.cool_colormap(t),
454            ColorMap::Spring => self.spring_colormap(t),
455            ColorMap::Summer => self.summer_colormap(t),
456            ColorMap::Autumn => self.autumn_colormap(t),
457            ColorMap::Winter => self.winter_colormap(t),
458            ColorMap::Gray => Vec3::splat(t),
459            ColorMap::Bone => self.bone_colormap(t),
460            ColorMap::Copper => self.copper_colormap(t),
461            ColorMap::Pink => self.pink_colormap(t),
462            ColorMap::Lines => self.lines_colormap(t),
463            ColorMap::Viridis => self.viridis_colormap(t),
464            ColorMap::Plasma => self.plasma_colormap(t),
465            ColorMap::Inferno => self.inferno_colormap(t),
466            ColorMap::Magma => self.magma_colormap(t),
467            ColorMap::Turbo => self.turbo_colormap(t),
468            ColorMap::Parula => self.parula_colormap(t),
469            ColorMap::Custom(min_color, max_color) => {
470                min_color.truncate().lerp(max_color.truncate(), t)
471            }
472        }
473    }
474
475    /// MATLAB Jet colormap
476    fn jet_colormap(&self, t: f32) -> Vec3 {
477        let r = (1.5 - 4.0 * (t - 0.75).abs()).clamp(0.0, 1.0);
478        let g = (1.5 - 4.0 * (t - 0.5).abs()).clamp(0.0, 1.0);
479        let b = (1.5 - 4.0 * (t - 0.25).abs()).clamp(0.0, 1.0);
480        Vec3::new(r, g, b)
481    }
482
483    /// Hot colormap (black -> red -> yellow -> white)
484    fn hot_colormap(&self, t: f32) -> Vec3 {
485        if t < 1.0 / 3.0 {
486            Vec3::new(3.0 * t, 0.0, 0.0)
487        } else if t < 2.0 / 3.0 {
488            Vec3::new(1.0, 3.0 * t - 1.0, 0.0)
489        } else {
490            Vec3::new(1.0, 1.0, 3.0 * t - 2.0)
491        }
492    }
493
494    /// Cool colormap (cyan -> magenta)
495    fn cool_colormap(&self, t: f32) -> Vec3 {
496        Vec3::new(t, 1.0 - t, 1.0)
497    }
498
499    /// Viridis colormap (perceptually uniform)
500    fn viridis_colormap(&self, t: f32) -> Vec3 {
501        // Simplified Viridis approximation
502        let r = (0.267004 + t * (0.993248 - 0.267004)).clamp(0.0, 1.0);
503        let g = (0.004874 + t * (0.906157 - 0.004874)).clamp(0.0, 1.0);
504        let b = (0.329415 + t * (0.143936 - 0.329415) + t * t * 0.5).clamp(0.0, 1.0);
505        Vec3::new(r, g, b)
506    }
507
508    /// Plasma colormap (perceptually uniform)
509    fn plasma_colormap(&self, t: f32) -> Vec3 {
510        // Simplified Plasma approximation
511        let r = (0.050383 + t * (0.940015 - 0.050383)).clamp(0.0, 1.0);
512        let g = (0.029803 + t * (0.975158 - 0.029803) * (1.0 - t)).clamp(0.0, 1.0);
513        let b = (0.527975 + t * (0.131326 - 0.527975)).clamp(0.0, 1.0);
514        Vec3::new(r, g, b)
515    }
516
517    /// Spring colormap (magenta -> yellow)
518    fn spring_colormap(&self, t: f32) -> Vec3 {
519        Vec3::new(1.0, t, 1.0 - t)
520    }
521
522    /// Summer colormap (green -> yellow)
523    fn summer_colormap(&self, t: f32) -> Vec3 {
524        Vec3::new(t, 0.5 + 0.5 * t, 0.4)
525    }
526
527    /// Autumn colormap (red -> yellow)
528    fn autumn_colormap(&self, t: f32) -> Vec3 {
529        Vec3::new(1.0, t, 0.0)
530    }
531
532    /// Winter colormap (blue -> green)
533    fn winter_colormap(&self, t: f32) -> Vec3 {
534        Vec3::new(0.0, t, 1.0 - 0.5 * t)
535    }
536
537    /// Bone colormap (black -> white with blue tint)
538    fn bone_colormap(&self, t: f32) -> Vec3 {
539        if t < 3.0 / 8.0 {
540            Vec3::new(7.0 / 8.0 * t, 7.0 / 8.0 * t, 29.0 / 24.0 * t)
541        } else {
542            Vec3::new(
543                (29.0 + 7.0 * t) / 24.0,
544                (29.0 + 7.0 * t) / 24.0,
545                (29.0 + 7.0 * t) / 24.0,
546            )
547        }
548    }
549
550    /// Copper colormap (black -> copper)
551    fn copper_colormap(&self, t: f32) -> Vec3 {
552        Vec3::new((1.25 * t).min(1.0), 0.7812 * t, 0.4975 * t)
553    }
554
555    /// Pink colormap (black -> pink -> white)
556    fn pink_colormap(&self, t: f32) -> Vec3 {
557        let sqrt_t = t.sqrt();
558        if t < 3.0 / 8.0 {
559            Vec3::new(14.0 / 9.0 * sqrt_t, 2.0 / 3.0 * sqrt_t, 2.0 / 3.0 * sqrt_t)
560        } else {
561            Vec3::new(
562                2.0 * sqrt_t - 1.0 / 3.0,
563                8.0 / 9.0 * sqrt_t + 1.0 / 3.0,
564                8.0 / 9.0 * sqrt_t + 1.0 / 3.0,
565            )
566        }
567    }
568
569    /// Lines colormap (cycling through basic colors)
570    fn lines_colormap(&self, t: f32) -> Vec3 {
571        let _phase = (t * 7.0) % 1.0; // For future use in color transitions
572        let index = (t * 7.0) as usize % 7;
573        match index {
574            0 => Vec3::new(0.0, 0.0, 1.0),    // Blue
575            1 => Vec3::new(0.0, 0.5, 0.0),    // Green
576            2 => Vec3::new(1.0, 0.0, 0.0),    // Red
577            3 => Vec3::new(0.0, 0.75, 0.75),  // Cyan
578            4 => Vec3::new(0.75, 0.0, 0.75),  // Magenta
579            5 => Vec3::new(0.75, 0.75, 0.0),  // Yellow
580            _ => Vec3::new(0.25, 0.25, 0.25), // Dark gray
581        }
582    }
583
584    /// Inferno colormap (perceptually uniform)
585    fn inferno_colormap(&self, t: f32) -> Vec3 {
586        // Simplified Inferno approximation
587        let r = (0.001462 + t * (0.988362 - 0.001462)).clamp(0.0, 1.0);
588        let g = (0.000466 + t * t * (0.982895 - 0.000466)).clamp(0.0, 1.0);
589        let b = (0.013866 + t * (1.0 - t) * (0.416065 - 0.013866)).clamp(0.0, 1.0);
590        Vec3::new(r, g, b)
591    }
592
593    /// Magma colormap (perceptually uniform)
594    fn magma_colormap(&self, t: f32) -> Vec3 {
595        // Simplified Magma approximation
596        let r = (0.001462 + t * (0.987053 - 0.001462)).clamp(0.0, 1.0);
597        let g = (0.000466 + t * t * (0.991438 - 0.000466)).clamp(0.0, 1.0);
598        let b = (0.013866 + t * (0.644237 - 0.013866) * (1.0 - t)).clamp(0.0, 1.0);
599        Vec3::new(r, g, b)
600    }
601
602    /// Turbo colormap (improved rainbow)
603    fn turbo_colormap(&self, t: f32) -> Vec3 {
604        // Simplified Turbo approximation (Google's improved rainbow)
605        let r = if t < 0.5 {
606            (0.13 + 0.87 * (2.0 * t).powf(0.25)).clamp(0.0, 1.0)
607        } else {
608            (0.8685 + 0.1315 * (2.0 * (1.0 - t)).powf(0.25)).clamp(0.0, 1.0)
609        };
610
611        let g = if t < 0.25 {
612            4.0 * t
613        } else if t < 0.75 {
614            1.0
615        } else {
616            1.0 - 4.0 * (t - 0.75)
617        }
618        .clamp(0.0, 1.0);
619
620        let b = if t < 0.5 {
621            (0.8 * (1.0 - 2.0 * t).powf(0.25)).clamp(0.0, 1.0)
622        } else {
623            (0.1 + 0.9 * (2.0 * t - 1.0).powf(0.25)).clamp(0.0, 1.0)
624        };
625
626        Vec3::new(r, g, b)
627    }
628
629    /// Parula colormap (MATLAB's default)
630    fn parula_colormap(&self, t: f32) -> Vec3 {
631        // Simplified Parula approximation
632        let r = if t < 0.25 {
633            0.2081 * (1.0 - t)
634        } else if t < 0.5 {
635            t - 0.25
636        } else if t < 0.75 {
637            1.0
638        } else {
639            1.0 - 0.5 * (t - 0.75)
640        }
641        .clamp(0.0, 1.0);
642
643        let g = if t < 0.125 {
644            0.1663 * t / 0.125
645        } else if t < 0.375 {
646            0.1663 + (0.7079 - 0.1663) * (t - 0.125) / 0.25
647        } else if t < 0.625 {
648            0.7079 + (0.9839 - 0.7079) * (t - 0.375) / 0.25
649        } else {
650            0.9839 * (1.0 - (t - 0.625) / 0.375)
651        }
652        .clamp(0.0, 1.0);
653
654        let b = if t < 0.25 {
655            0.5 + 0.5 * t / 0.25
656        } else if t < 0.5 {
657            1.0
658        } else {
659            1.0 - 2.0 * (t - 0.5)
660        }
661        .clamp(0.0, 1.0);
662
663        Vec3::new(r, g, b)
664    }
665
666    /// Default colormap fallback
667    #[allow(dead_code)] // Fallback method for colormap errors
668    fn default_colormap(&self, t: f32) -> Vec3 {
669        // Use a simple RGB transition as fallback
670        if t < 0.5 {
671            Vec3::new(0.0, 2.0 * t, 1.0 - 2.0 * t)
672        } else {
673            Vec3::new(2.0 * (t - 0.5), 1.0 - 2.0 * (t - 0.5), 0.0)
674        }
675    }
676}
677
678/// MATLAB-compatible surface plot creation utilities
679pub mod matlab_compat {
680    use super::*;
681
682    /// Create a surface plot (equivalent to MATLAB's `surf(X, Y, Z)`)
683    pub fn surf(x: Vec<f64>, y: Vec<f64>, z: Vec<Vec<f64>>) -> Result<SurfacePlot, String> {
684        SurfacePlot::new(x, y, z)
685    }
686
687    /// Create a mesh plot (wireframe surface)
688    pub fn mesh(x: Vec<f64>, y: Vec<f64>, z: Vec<Vec<f64>>) -> Result<SurfacePlot, String> {
689        Ok(SurfacePlot::new(x, y, z)?
690            .with_wireframe(true)
691            .with_shading(ShadingMode::None))
692    }
693
694    /// Create surface from meshgrid
695    pub fn meshgrid_surf(
696        x_range: (f64, f64),
697        y_range: (f64, f64),
698        resolution: (usize, usize),
699        func: impl Fn(f64, f64) -> f64,
700    ) -> Result<SurfacePlot, String> {
701        SurfacePlot::from_function(x_range, y_range, resolution, func)
702    }
703
704    /// Create surface with specific colormap
705    pub fn surf_with_colormap(
706        x: Vec<f64>,
707        y: Vec<f64>,
708        z: Vec<Vec<f64>>,
709        colormap: &str,
710    ) -> Result<SurfacePlot, String> {
711        let cmap = match colormap {
712            "jet" => ColorMap::Jet,
713            "hot" => ColorMap::Hot,
714            "cool" => ColorMap::Cool,
715            "viridis" => ColorMap::Viridis,
716            "plasma" => ColorMap::Plasma,
717            "gray" | "grey" => ColorMap::Gray,
718            _ => return Err(format!("Unknown colormap: {colormap}")),
719        };
720
721        Ok(SurfacePlot::new(x, y, z)?.with_colormap(cmap))
722    }
723}
724
#[cfg(test)]
mod tests {
    use super::*;

    /// 2x2 grid (z = x + y) shared by several tests.
    fn unit_grid() -> (Vec<f64>, Vec<f64>, Vec<Vec<f64>>) {
        let axis = vec![0.0, 1.0];
        let heights = vec![vec![0.0, 1.0], vec![1.0, 2.0]];
        (axis.clone(), axis, heights)
    }

    #[test]
    fn test_surface_plot_creation() {
        let xs = vec![0.0, 1.0, 2.0];
        let ys = vec![0.0, 1.0];
        let zs: Vec<Vec<f64>> = xs
            .iter()
            .map(|&x| ys.iter().map(|&y| x + y).collect())
            .collect();

        let plot = SurfacePlot::new(xs, ys, zs).expect("dimensions match");

        assert_eq!(plot.x_data.len(), 3);
        assert_eq!(plot.y_data.len(), 2);
        assert_eq!(plot.z_data.len(), 3);
        assert_eq!(plot.z_data[0].len(), 2);
        assert!(plot.visible);
    }

    #[test]
    fn test_surface_from_function() {
        let paraboloid = |x: f64, y: f64| x * x + y * y;
        let plot = SurfacePlot::from_function((-2.0, 2.0), (-2.0, 2.0), (10, 10), paraboloid)
            .expect("resolution is valid");

        for len in [plot.x_data.len(), plot.y_data.len(), plot.z_data.len()] {
            assert_eq!(len, 10);
        }

        // Corner sample: (-2)^2 + (-2)^2 = 8
        assert_eq!(plot.z_data[0][0], 8.0);
    }

    #[test]
    fn test_surface_validation() {
        // Rows have 2 entries but there are 3 Y coordinates.
        let short_rows = vec![vec![0.0, 1.0], vec![1.0, 2.0]];
        let result = SurfacePlot::new(vec![0.0, 1.0], vec![0.0, 1.0, 2.0], short_rows);
        assert!(result.is_err());
    }

    #[test]
    fn test_surface_styling() {
        let (xs, ys, zs) = unit_grid();
        let plot = SurfacePlot::new(xs, ys, zs)
            .expect("dimensions match")
            .with_colormap(ColorMap::Hot)
            .with_wireframe(true)
            .with_alpha(0.8)
            .with_label("Test Surface");

        assert_eq!(plot.colormap, ColorMap::Hot);
        assert!(plot.wireframe);
        assert_eq!(plot.alpha, 0.8);
        assert_eq!(plot.label, Some("Test Surface".to_string()));
    }

    #[test]
    fn test_colormap_mapping() {
        let jet = ColorMap::Jet;
        let (low, mid, high) = (jet.map_value(0.0), jet.map_value(0.5), jet.map_value(1.0));

        // Boundary values stay in range.
        assert!((0.0..=1.0).contains(&low.x));
        assert!((0.0..=1.0).contains(&high.x));

        // Different inputs give different colors.
        assert_ne!(low, mid);
        assert_ne!(mid, high);
    }

    #[test]
    fn test_surface_statistics() {
        let xs: Vec<f64> = (0..4).map(|v| v as f64).collect();
        let ys: Vec<f64> = (0..3).map(|v| v as f64).collect();
        let zs: Vec<Vec<f64>> = (0..4)
            .map(|i| (0..3).map(|j| (i + j) as f64).collect())
            .collect();

        let stats = SurfacePlot::new(xs, ys, zs)
            .expect("dimensions match")
            .statistics();

        assert_eq!(stats.grid_points, 12); // 4 * 3
        assert_eq!(stats.triangle_count, 12); // (4-1) * (3-1) * 2
        assert_eq!(stats.x_resolution, 4);
        assert_eq!(stats.y_resolution, 3);
        assert!(stats.memory_usage > 0);
    }

    #[test]
    fn test_matlab_compat() {
        use super::matlab_compat::*;

        let (xs, ys, zs) = unit_grid();

        let filled = surf(xs.clone(), ys.clone(), zs.clone()).unwrap();
        assert!(!filled.wireframe);

        let wire = mesh(xs.clone(), ys.clone(), zs.clone()).unwrap();
        assert!(wire.wireframe);

        let named = surf_with_colormap(xs, ys, zs, "viridis").unwrap();
        assert_eq!(named.colormap, ColorMap::Viridis);
    }
}