runmat_kernel/
jupyter_plotting.rs

1//! Jupyter plotting integration for the RunMat kernel
2//!
3//! Provides seamless integration between the RunMat kernel and the plotting system,
4//! enabling automatic plot display in Jupyter notebooks.
5
6#[cfg(feature = "jupyter")]
7use crate::{KernelError, Result};
8#[cfg(feature = "jupyter")]
9use runmat_plot::jupyter::{JupyterBackend, OutputFormat};
10#[cfg(feature = "jupyter")]
11use runmat_plot::plots::Figure;
12#[cfg(feature = "jupyter")]
13use serde_json::Value as JsonValue;
14#[cfg(feature = "jupyter")]
15use std::collections::HashMap;
16
/// Session-level state for showing RunMat plots inside Jupyter.
///
/// Holds the rendering backend, the active configuration and every figure
/// registered so far.
#[cfg(feature = "jupyter")]
pub struct JupyterPlottingManager {
    /// Backend that renders a figure into displayable content.
    backend: JupyterBackend,
    /// Settings that control output format, auto-display and retention.
    config: JupyterPlottingConfig,
    /// Registered figures, keyed by their `plot_<n>` identifier.
    active_plots: HashMap<String, Figure>,
    /// Running count of registered figures; used to mint the next identifier.
    plot_counter: u64,
}
29
/// User-tunable settings for [`JupyterPlottingManager`].
#[cfg(feature = "jupyter")]
#[derive(Debug, Clone)]
pub struct JupyterPlottingConfig {
    /// Format the backend renders figures into.
    pub output_format: OutputFormat,
    /// Whether a figure is rendered as soon as it is registered
    /// (only takes effect when `inline_display` is also set).
    pub auto_display: bool,
    /// Upper bound on the number of figures retained in memory.
    pub max_plots: usize,
    /// Whether figures are displayed inline
    /// (only takes effect when `auto_display` is also set).
    pub inline_display: bool,
    /// Width reported in the metadata of HTML output.
    pub image_width: u32,
    /// Height reported in the metadata of HTML output.
    pub image_height: u32,
}
47
/// Payload returned to the caller for a rendered figure.
///
/// Gated on the `jupyter` feature because its field types (`HashMap`,
/// `JsonValue`) are only imported when that feature is enabled.
#[cfg(feature = "jupyter")]
#[derive(Debug, Clone)]
pub struct DisplayData {
    /// Rendered content, keyed by MIME type.
    pub data: HashMap<String, JsonValue>,
    /// Per-MIME-type display metadata.
    pub metadata: HashMap<String, JsonValue>,
    /// Transient data attached to the display (plot ID and RunMat version).
    pub transient: HashMap<String, JsonValue>,
}
58
#[cfg(feature = "jupyter")]
impl Default for JupyterPlottingConfig {
    /// Defaults: HTML output, inline auto-display enabled, at most 100
    /// retained plots, and an 800x600 image size.
    fn default() -> Self {
        let (image_width, image_height) = (800, 600);

        Self {
            output_format: OutputFormat::HTML,
            auto_display: true,
            max_plots: 100,
            inline_display: true,
            image_width,
            image_height,
        }
    }
}
72
73#[cfg(feature = "jupyter")]
74impl JupyterPlottingManager {
75    /// Create a new Jupyter plotting manager
76    pub fn new() -> Self {
77        Self::with_config(JupyterPlottingConfig::default())
78    }
79
80    /// Create manager with specific configuration
81    pub fn with_config(config: JupyterPlottingConfig) -> Self {
82        let backend = match config.output_format {
83            OutputFormat::HTML => JupyterBackend::with_format(OutputFormat::HTML),
84            OutputFormat::PNG => JupyterBackend::with_format(OutputFormat::PNG),
85            OutputFormat::SVG => JupyterBackend::with_format(OutputFormat::SVG),
86            OutputFormat::Base64 => JupyterBackend::with_format(OutputFormat::Base64),
87            OutputFormat::PlotlyJSON => JupyterBackend::with_format(OutputFormat::PlotlyJSON),
88        };
89
90        Self {
91            backend,
92            config,
93            active_plots: HashMap::new(),
94            plot_counter: 0,
95        }
96    }
97
98    /// Register a new plot and optionally display it
99    pub fn register_plot(&mut self, mut figure: Figure) -> Result<Option<DisplayData>> {
100        self.plot_counter += 1;
101        let plot_id = format!("plot_{}", self.plot_counter);
102
103        // Store the plot
104        self.active_plots.insert(plot_id.clone(), figure.clone());
105
106        // Clean up old plots if we exceed the maximum
107        if self.active_plots.len() > self.config.max_plots {
108            self.cleanup_old_plots();
109        }
110
111        // Auto-display if enabled
112        if self.config.auto_display && self.config.inline_display {
113            let display_data = self.create_display_data(&mut figure)?;
114            Ok(Some(display_data))
115        } else {
116            Ok(None)
117        }
118    }
119
120    /// Create display data for a figure
121    pub fn create_display_data(&mut self, figure: &mut Figure) -> Result<DisplayData> {
122        let mut data = HashMap::new();
123        let mut metadata = HashMap::new();
124
125        // Generate content based on output format
126        match self.config.output_format {
127            OutputFormat::HTML => {
128                let html_content = self
129                    .backend
130                    .display_figure(figure)
131                    .map_err(|e| KernelError::Execution(format!("HTML generation failed: {e}")))?;
132
133                data.insert("text/html".to_string(), JsonValue::String(html_content));
134                metadata.insert(
135                    "text/html".to_string(),
136                    JsonValue::Object({
137                        let mut meta = serde_json::Map::new();
138                        meta.insert("isolated".to_string(), JsonValue::Bool(true));
139                        meta.insert(
140                            "width".to_string(),
141                            JsonValue::Number(self.config.image_width.into()),
142                        );
143                        meta.insert(
144                            "height".to_string(),
145                            JsonValue::Number(self.config.image_height.into()),
146                        );
147                        meta
148                    }),
149                );
150            }
151            OutputFormat::PNG => {
152                let png_content = self
153                    .backend
154                    .display_figure(figure)
155                    .map_err(|e| KernelError::Execution(format!("PNG generation failed: {e}")))?;
156
157                data.insert("text/html".to_string(), JsonValue::String(png_content));
158            }
159            OutputFormat::SVG => {
160                let svg_content = self
161                    .backend
162                    .display_figure(figure)
163                    .map_err(|e| KernelError::Execution(format!("SVG generation failed: {e}")))?;
164
165                data.insert("image/svg+xml".to_string(), JsonValue::String(svg_content));
166                metadata.insert(
167                    "image/svg+xml".to_string(),
168                    JsonValue::Object({
169                        let mut meta = serde_json::Map::new();
170                        meta.insert("isolated".to_string(), JsonValue::Bool(true));
171                        meta
172                    }),
173                );
174            }
175            OutputFormat::Base64 => {
176                let base64_content = self.backend.display_figure(figure).map_err(|e| {
177                    KernelError::Execution(format!("Base64 generation failed: {e}"))
178                })?;
179
180                data.insert("text/html".to_string(), JsonValue::String(base64_content));
181            }
182            OutputFormat::PlotlyJSON => {
183                let plotly_content = self.backend.display_figure(figure).map_err(|e| {
184                    KernelError::Execution(format!("Plotly generation failed: {e}"))
185                })?;
186
187                data.insert("text/html".to_string(), JsonValue::String(plotly_content));
188                metadata.insert(
189                    "text/html".to_string(),
190                    JsonValue::Object({
191                        let mut meta = serde_json::Map::new();
192                        meta.insert("isolated".to_string(), JsonValue::Bool(true));
193                        meta
194                    }),
195                );
196            }
197        }
198
199        // Add RunMat metadata
200        let mut transient = HashMap::new();
201        transient.insert(
202            "runmat_plot_id".to_string(),
203            JsonValue::String(format!("plot_{}", self.plot_counter)),
204        );
205        transient.insert(
206            "runmat_version".to_string(),
207            JsonValue::String("0.0.1".to_string()),
208        );
209
210        Ok(DisplayData {
211            data,
212            metadata,
213            transient,
214        })
215    }
216
217    /// Get a plot by ID
218    pub fn get_plot(&self, plot_id: &str) -> Option<&Figure> {
219        self.active_plots.get(plot_id)
220    }
221
222    /// List all active plots
223    pub fn list_plots(&self) -> Vec<String> {
224        self.active_plots.keys().cloned().collect()
225    }
226
227    /// Clear all plots
228    pub fn clear_plots(&mut self) {
229        self.active_plots.clear();
230        self.plot_counter = 0;
231    }
232
233    /// Update configuration
234    pub fn update_config(&mut self, config: JupyterPlottingConfig) {
235        self.config = config;
236
237        // Update backend format if needed
238        self.backend = match self.config.output_format {
239            OutputFormat::HTML => JupyterBackend::with_format(OutputFormat::HTML),
240            OutputFormat::PNG => JupyterBackend::with_format(OutputFormat::PNG),
241            OutputFormat::SVG => JupyterBackend::with_format(OutputFormat::SVG),
242            OutputFormat::Base64 => JupyterBackend::with_format(OutputFormat::Base64),
243            OutputFormat::PlotlyJSON => JupyterBackend::with_format(OutputFormat::PlotlyJSON),
244        };
245    }
246
247    /// Get current configuration
248    pub fn config(&self) -> &JupyterPlottingConfig {
249        &self.config
250    }
251
252    /// Clean up old plots to maintain memory limits
253    fn cleanup_old_plots(&mut self) {
254        // Simple cleanup: remove oldest plots
255        let mut plot_ids: Vec<String> = self.active_plots.keys().cloned().collect();
256        plot_ids.sort();
257
258        while self.active_plots.len() > self.config.max_plots {
259            if let Some(oldest_id) = plot_ids.first() {
260                self.active_plots.remove(oldest_id);
261                plot_ids.remove(0);
262            } else {
263                break;
264            }
265        }
266    }
267
268    /// Handle plot function calls from code
269    pub fn handle_plot_function(
270        &mut self,
271        function_name: &str,
272        args: &[JsonValue],
273    ) -> Result<Option<DisplayData>> {
274        println!(
275            "DEBUG: Handling plot function '{}' with {} args",
276            function_name,
277            args.len()
278        );
279
280        // Create appropriate plot based on function name
281        let mut figure = Figure::new();
282
283        match function_name {
284            "plot" => {
285                if args.len() >= 2 {
286                    // Extract x and y data from arguments
287                    let x_data = self.extract_numeric_array(&args[0])?;
288                    let y_data = self.extract_numeric_array(&args[1])?;
289
290                    if x_data.len() == y_data.len() {
291                        let line_plot =
292                            runmat_plot::plots::LinePlot::new(x_data, y_data).map_err(|e| {
293                                KernelError::Execution(format!("Failed to create line plot: {e}"))
294                            })?;
295                        figure.add_line_plot(line_plot);
296                    } else {
297                        return Err(KernelError::Execution(
298                            "X and Y data must have the same length".to_string(),
299                        ));
300                    }
301                }
302            }
303            "scatter" => {
304                if args.len() >= 2 {
305                    let x_data = self.extract_numeric_array(&args[0])?;
306                    let y_data = self.extract_numeric_array(&args[1])?;
307
308                    if x_data.len() == y_data.len() {
309                        let scatter_plot = runmat_plot::plots::ScatterPlot::new(x_data, y_data)
310                            .map_err(KernelError::Execution)?;
311                        figure.add_scatter_plot(scatter_plot);
312                    } else {
313                        return Err(KernelError::Execution(
314                            "X and Y data must have the same length".to_string(),
315                        ));
316                    }
317                }
318            }
319            "bar" => {
320                if !args.is_empty() {
321                    let y_data = self.extract_numeric_array(&args[0])?;
322                    let x_labels: Vec<String> = (0..y_data.len()).map(|i| format!("{i}")).collect();
323
324                    let bar_chart = runmat_plot::plots::BarChart::new(x_labels, y_data)
325                        .map_err(KernelError::Execution)?;
326                    figure.add_bar_chart(bar_chart);
327                }
328            }
329            "hist" => {
330                if !args.is_empty() {
331                    let data = self.extract_numeric_array(&args[0])?;
332                    let bins = if args.len() > 1 {
333                        self.extract_number(&args[1])? as usize
334                    } else {
335                        20
336                    };
337
338                    let histogram = runmat_plot::plots::Histogram::new(data, bins)
339                        .map_err(KernelError::Execution)?;
340                    figure.add_histogram(histogram);
341                }
342            }
343            _ => {
344                return Err(KernelError::Execution(format!(
345                    "Unknown plot function: {function_name}"
346                )));
347            }
348        }
349
350        // Register and potentially display the plot
351        self.register_plot(figure)
352    }
353
354    /// Extract numeric array from JSON value
355    fn extract_numeric_array(&self, value: &JsonValue) -> Result<Vec<f64>> {
356        match value {
357            JsonValue::Array(arr) => {
358                let mut result = Vec::new();
359                for item in arr {
360                    if let Some(num) = item.as_f64() {
361                        result.push(num);
362                    } else if let Some(num) = item.as_i64() {
363                        result.push(num as f64);
364                    } else {
365                        return Err(KernelError::Execution(
366                            "Array must contain only numbers".to_string(),
367                        ));
368                    }
369                }
370                Ok(result)
371            }
372            JsonValue::Number(num) => {
373                if let Some(val) = num.as_f64() {
374                    Ok(vec![val])
375                } else {
376                    Err(KernelError::Execution("Invalid number format".to_string()))
377                }
378            }
379            _ => Err(KernelError::Execution(
380                "Expected array or number".to_string(),
381            )),
382        }
383    }
384
385    /// Extract single number from JSON value
386    fn extract_number(&self, value: &JsonValue) -> Result<f64> {
387        match value {
388            JsonValue::Number(num) => num
389                .as_f64()
390                .ok_or_else(|| KernelError::Execution("Invalid number format".to_string())),
391            _ => Err(KernelError::Execution("Expected number".to_string())),
392        }
393    }
394}
395
// Gated like the type itself: `JupyterPlottingManager` only exists when the
// `jupyter` feature is enabled.
#[cfg(feature = "jupyter")]
impl Default for JupyterPlottingManager {
    /// Equivalent to [`JupyterPlottingManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
401
/// Extension trait for ExecutionEngine to add Jupyter plotting support.
///
/// Gated on the `jupyter` feature because every type in its signatures
/// (`Result`, `JsonValue`, `DisplayData`, `JupyterPlottingManager`) is only
/// available when that feature is enabled.
#[cfg(feature = "jupyter")]
pub trait JupyterPlottingExtension {
    /// Handle a plot function call, returning display data when the plot
    /// should be shown.
    fn handle_jupyter_plot(
        &mut self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<Option<DisplayData>>;

    /// Mutable access to the plotting manager owned by the implementor.
    fn plotting_manager(&mut self) -> &mut JupyterPlottingManager;
}
414
// Note: This trait still needs to be implemented on ExecutionEngine in the actual integration.
// For now, only the trait definition is provided; this file contains no implementation of it.
417
// Every test exercises items that only exist with the `jupyter` feature, so
// the module is compiled only when both conditions hold.
#[cfg(all(test, feature = "jupyter"))]
mod tests {
    use super::*;

    /// Build a JSON array of integers.
    fn int_array(values: &[i64]) -> JsonValue {
        JsonValue::Array(
            values
                .iter()
                .map(|&v| JsonValue::Number(serde_json::Number::from(v)))
                .collect(),
        )
    }

    #[test]
    fn test_jupyter_plotting_manager_creation() {
        let manager = JupyterPlottingManager::new();
        assert_eq!(manager.config.output_format, OutputFormat::HTML);
        assert!(manager.config.auto_display);
        assert_eq!(manager.active_plots.len(), 0);
    }

    #[test]
    fn test_config_update() {
        let mut manager = JupyterPlottingManager::new();

        let new_config = JupyterPlottingConfig {
            output_format: OutputFormat::SVG,
            auto_display: false,
            max_plots: 50,
            inline_display: false,
            image_width: 1024,
            image_height: 768,
        };

        manager.update_config(new_config);
        assert_eq!(manager.config.output_format, OutputFormat::SVG);
        assert!(!manager.config.auto_display);
        assert_eq!(manager.config.max_plots, 50);
    }

    #[test]
    fn test_plot_management() {
        let mut manager = JupyterPlottingManager::new();
        let figure = Figure::new().with_title("Test Plot");

        // Register a plot
        let display_data = manager.register_plot(figure).unwrap();
        assert!(display_data.is_some());
        assert_eq!(manager.active_plots.len(), 1);
        assert_eq!(manager.list_plots().len(), 1);

        // Clear plots
        manager.clear_plots();
        assert_eq!(manager.active_plots.len(), 0);
        assert_eq!(manager.plot_counter, 0);
    }

    #[test]
    fn test_extract_numeric_array() {
        let manager = JupyterPlottingManager::new();

        let json_array = int_array(&[1, 2, 3]);

        let result = manager.extract_numeric_array(&json_array).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_plot_function_handling() {
        let mut manager = JupyterPlottingManager::new();

        let x_data = int_array(&[1, 2, 3]);
        let y_data = int_array(&[2, 4, 6]);

        let result = manager
            .handle_plot_function("plot", &[x_data, y_data])
            .unwrap();
        assert!(result.is_some());
        assert_eq!(manager.active_plots.len(), 1);
    }
}