1use runmat_builtins::Tensor;
7use runmat_macros::runtime_builtin;
8use runmat_plot::plots::figure::PlotElement;
9use runmat_plot::plots::*;
10use std::env;
11
12fn get_plotting_mode() -> PlottingMode {
14 if let Ok(mode) = env::var("RUSTMAT_PLOT_MODE") {
16 match mode.to_lowercase().as_str() {
17 "gui" => PlottingMode::Interactive,
18 "headless" => PlottingMode::Static,
19 "jupyter" => PlottingMode::Jupyter,
20 _ => PlottingMode::Auto,
21 }
22 } else {
23 PlottingMode::Auto
24 }
25}
26
/// Strategy used by `execute_plot` to present a figure.
#[derive(Debug, Clone, Copy)]
enum PlottingMode {
    /// No explicit choice: `execute_plot` takes the Jupyter path when
    /// `JPY_PARENT_PID` or `JUPYTER_RUNTIME_DIR` is set, otherwise the
    /// interactive path.
    Auto,
    /// Interactive display; selected by `RUSTMAT_PLOT_MODE=gui`.
    Interactive,
    /// Export to an image file; selected by `RUSTMAT_PLOT_MODE=headless`.
    Static,
    /// Jupyter display; selected by `RUSTMAT_PLOT_MODE=jupyter`.
    Jupyter,
}
39
40fn execute_plot(mut figure: Figure) -> Result<String, String> {
42 match get_plotting_mode() {
43 PlottingMode::Interactive => interactive_export(&mut figure),
44 PlottingMode::Static => static_export(&mut figure, "plot.png"),
45 PlottingMode::Jupyter => jupyter_export(&mut figure),
46 PlottingMode::Auto => {
47 if env::var("JPY_PARENT_PID").is_ok() || env::var("JUPYTER_RUNTIME_DIR").is_ok() {
49 jupyter_export(&mut figure)
50 } else {
51 interactive_export(&mut figure)
52 }
53 }
54 }
55}
56
57fn interactive_export(figure: &mut Figure) -> Result<String, String> {
59 let figure_clone = figure.clone();
61
62 match runmat_plot::show_interactive_platform_optimal(figure_clone) {
64 Ok(result) => Ok(result),
65 Err(e) => {
66 Err(format!(
69 "Interactive plotting failed: {e}. Please check GPU/GUI system setup."
70 ))
71 }
72 }
73}
74
75fn static_export(figure: &mut Figure, filename: &str) -> Result<String, String> {
77 if figure.is_empty() {
80 return Err("No plots found in figure to export".to_string());
81 }
82
83 match figure.get_plot_mut(0) {
84 Some(PlotElement::Line(line_plot)) => {
85 let x_data: Vec<f64> = line_plot.x_data.to_vec();
86 let y_data: Vec<f64> = line_plot.y_data.to_vec();
87
88 runmat_plot::plot_line(
89 &x_data,
90 &y_data,
91 filename,
92 runmat_plot::PlotOptions::default(),
93 )
94 .map_err(|e| format!("Plot export failed: {e}"))?;
95
96 Ok(format!("Plot saved to {filename}"))
97 }
98 Some(PlotElement::Scatter(scatter_plot)) => {
99 let x_data: Vec<f64> = scatter_plot.x_data.to_vec();
100 let y_data: Vec<f64> = scatter_plot.y_data.to_vec();
101
102 runmat_plot::plot_scatter(
103 &x_data,
104 &y_data,
105 filename,
106 runmat_plot::PlotOptions::default(),
107 )
108 .map_err(|e| format!("Plot export failed: {e}"))?;
109
110 Ok(format!("Scatter plot saved to {filename}"))
111 }
112 Some(PlotElement::Bar(bar_chart)) => {
113 let values: Vec<f64> = bar_chart.values.to_vec();
114
115 runmat_plot::plot_bar(
116 &bar_chart.labels,
117 &values,
118 filename,
119 runmat_plot::PlotOptions::default(),
120 )
121 .map_err(|e| format!("Plot export failed: {e}"))?;
122
123 Ok(format!("Bar chart saved to {filename}"))
124 }
125 Some(PlotElement::Histogram(histogram)) => {
126 let data: Vec<f64> = histogram.data.to_vec();
127
128 runmat_plot::plot_histogram(
129 &data,
130 histogram.bins,
131 filename,
132 runmat_plot::PlotOptions::default(),
133 )
134 .map_err(|e| format!("Plot export failed: {e}"))?;
135
136 Ok(format!("Histogram saved to {filename}"))
137 }
138 Some(PlotElement::PointCloud(point_cloud)) => {
139 let x_data: Vec<f64> = point_cloud
141 .positions
142 .iter()
143 .map(|pos| pos.x as f64)
144 .collect();
145 let y_data: Vec<f64> = point_cloud
146 .positions
147 .iter()
148 .map(|pos| pos.y as f64)
149 .collect();
150
151 runmat_plot::plot_scatter(
152 &x_data,
153 &y_data,
154 filename,
155 runmat_plot::PlotOptions::default(),
156 )
157 .map_err(|e| format!("Point cloud export failed: {e}"))?;
158
159 Ok(format!("Point cloud (2D projection) saved to {filename}"))
160 }
161 None => Err("No plots found in figure to export".to_string()),
162 }
163}
164
/// Displays `figure` through a freshly created `JupyterBackend`.
///
/// Compiled only when the `jupyter` feature is enabled; the backend's result
/// is returned unchanged.
#[cfg(feature = "jupyter")]
fn jupyter_export(figure: &mut Figure) -> Result<String, String> {
    runmat_plot::jupyter::JupyterBackend::new().display_figure(figure)
}
172
173#[cfg(not(feature = "jupyter"))]
174fn jupyter_export(_figure: &mut Figure) -> Result<String, String> {
175 Err("Jupyter feature not enabled".to_string())
176}
177
178fn extract_numeric_vector(matrix: &Tensor) -> Vec<f64> {
180 matrix.data.clone()
181}
182
183#[runtime_builtin(name = "plot")]
184fn plot_builtin(x: Tensor, y: Tensor) -> Result<String, String> {
185 let x_data = extract_numeric_vector(&x);
186 let y_data = extract_numeric_vector(&y);
187
188 if x_data.len() != y_data.len() {
189 return Err("X and Y data must have the same length".to_string());
190 }
191
192 let line_plot = LinePlot::new(x_data.clone(), y_data.clone())
194 .map_err(|e| format!("Failed to create line plot: {e}"))?
195 .with_label("Data")
196 .with_style(
197 glam::Vec4::new(0.0, 0.4, 0.8, 1.0), 3.0,
199 LineStyle::Solid,
200 );
201
202 let mut figure = Figure::new()
203 .with_title("Plot")
204 .with_labels("X", "Y")
205 .with_grid(true);
206
207 figure.add_line_plot(line_plot);
208
209 execute_plot(figure)
210}
211
212#[runtime_builtin(name = "scatter")]
213fn scatter_builtin(x: Tensor, y: Tensor) -> Result<String, String> {
214 let x_data = extract_numeric_vector(&x);
215 let y_data = extract_numeric_vector(&y);
216
217 if x_data.len() != y_data.len() {
218 return Err("X and Y data must have the same length".to_string());
219 }
220
221 let scatter_plot = ScatterPlot::new(x_data, y_data)
223 .map_err(|e| format!("Failed to create scatter plot: {e}"))?
224 .with_label("Data")
225 .with_style(
226 glam::Vec4::new(0.8, 0.2, 0.2, 1.0), 5.0,
228 MarkerStyle::Circle,
229 );
230
231 let mut figure = Figure::new()
232 .with_title("Scatter Plot")
233 .with_labels("X", "Y")
234 .with_grid(true);
235
236 figure.add_scatter_plot(scatter_plot);
237
238 execute_plot(figure)
239}
240
241#[runtime_builtin(name = "bar")]
242fn bar_builtin(values: Tensor) -> Result<String, String> {
243 let data = extract_numeric_vector(&values);
244
245 let labels: Vec<String> = (1..=data.len()).map(|i| format!("Item {i}")).collect();
247
248 let bar_chart = BarChart::new(labels, data)
250 .map_err(|e| format!("Failed to create bar chart: {e}"))?
251 .with_label("Values")
252 .with_style(glam::Vec4::new(0.2, 0.6, 0.3, 1.0), 0.8); let mut figure = Figure::new()
255 .with_title("Bar Chart")
256 .with_labels("Categories", "Values")
257 .with_grid(true);
258
259 figure.add_bar_chart(bar_chart);
260
261 execute_plot(figure)
262}
263
264#[runtime_builtin(name = "hist")]
265fn hist_builtin(values: Tensor) -> Result<String, String> {
266 let data = extract_numeric_vector(&values);
267 let bins = 10; let histogram = Histogram::new(data, bins)
271 .map_err(|e| format!("Failed to create histogram: {e}"))?
272 .with_label("Frequency")
273 .with_style(glam::Vec4::new(0.6, 0.3, 0.7, 1.0), false); let mut figure = Figure::new()
276 .with_title("Histogram")
277 .with_labels("Values", "Frequency")
278 .with_grid(true);
279
280 figure.add_histogram(histogram);
281
282 execute_plot(figure)
283}
284
285#[runtime_builtin(name = "surf")]
288fn surf_builtin(x: Tensor, y: Tensor, z: Tensor) -> Result<String, String> {
289 let x_data = extract_numeric_vector(&x);
292 let y_data = extract_numeric_vector(&y);
293 let z_data_flat = extract_numeric_vector(&z);
294
295 let grid_size = (z_data_flat.len() as f64).sqrt() as usize;
297 if grid_size * grid_size != z_data_flat.len() {
298 return Err("Z data must form a square grid".to_string());
299 }
300
301 let mut z_grid = Vec::new();
302 for i in 0..grid_size {
303 let mut row = Vec::new();
304 for j in 0..grid_size {
305 row.push(z_data_flat[i * grid_size + j]);
306 }
307 z_grid.push(row);
308 }
309
310 let surface = SurfacePlot::new(x_data, y_data, z_grid)
312 .map_err(|e| format!("Failed to create surface plot: {e}"))?
313 .with_colormap(ColorMap::Viridis)
314 .with_label("Surface");
315
316 Ok(format!(
318 "3D Surface plot created with {} points",
319 surface.len()
320 ))
321}
322
323#[runtime_builtin(name = "scatter3")]
324fn scatter3_builtin(x: Tensor, y: Tensor, z: Tensor) -> Result<String, String> {
325 let x_data = extract_numeric_vector(&x);
326 let y_data = extract_numeric_vector(&y);
327 let z_data = extract_numeric_vector(&z);
328
329 if x_data.len() != y_data.len() || y_data.len() != z_data.len() {
330 return Err("X, Y, and Z data must have the same length".to_string());
331 }
332
333 println!("DEBUG: Creating scatter3 plot with {} points", x_data.len());
334
335 let positions: Vec<glam::Vec3> = x_data
337 .iter()
338 .zip(y_data.iter())
339 .zip(z_data.iter())
340 .map(|((&x, &y), &z)| glam::Vec3::new(x as f32, y as f32, z as f32))
341 .collect();
342
343 let point_cloud = PointCloudPlot::new(positions)
345 .with_default_color(glam::Vec4::new(1.0, 0.6, 0.2, 1.0)) .with_default_size(8.0) .with_point_style(PointStyle::Square) .with_colormap(ColorMap::Hot) .with_label("3D Point Cloud");
350
351 println!(
352 "DEBUG: PointCloudPlot created successfully with {} points",
353 point_cloud.len()
354 );
355
356 let mut figure = Figure::new()
358 .with_title("High-Performance 3D Point Cloud")
359 .with_labels("X", "Y"); figure.add_point_cloud_plot(point_cloud);
363
364 println!("DEBUG: 3D Point cloud added to figure, executing plot...");
365
366 execute_plot(figure)
368}
369
370#[runtime_builtin(name = "mesh")]
371fn mesh_builtin(x: Tensor, y: Tensor, z: Tensor) -> Result<String, String> {
372 let result = surf_builtin(x, y, z)?;
374 Ok(result.replace("Surface", "Mesh (wireframe)"))
375}