1#[cfg(feature = "jupyter")]
7use crate::{KernelError, Result};
8#[cfg(feature = "jupyter")]
9use runmat_plot::jupyter::{JupyterBackend, OutputFormat};
10#[cfg(feature = "jupyter")]
11use runmat_plot::plots::Figure;
12#[cfg(feature = "jupyter")]
13use serde_json::Value as JsonValue;
14#[cfg(feature = "jupyter")]
15use std::collections::HashMap;
16
/// Tracks figures produced in a Jupyter session and turns them into
/// notebook display payloads.
#[cfg(feature = "jupyter")]
pub struct JupyterPlottingManager {
    /// Renderer used by `create_display_data`; rebuilt whenever the
    /// configured output format changes.
    backend: JupyterBackend,
    /// Display and retention settings currently in effect.
    config: JupyterPlottingConfig,
    /// Registered figures keyed by ids of the form `plot_<n>`.
    active_plots: HashMap<String, Figure>,
    /// Sequence number of the most recently registered plot (reset by
    /// `clear_plots`).
    plot_counter: u64,
}
29
/// User-tunable settings for [`JupyterPlottingManager`].
#[cfg(feature = "jupyter")]
#[derive(Debug, Clone)]
pub struct JupyterPlottingConfig {
    /// Rendering format; also decides which MIME type the rendered
    /// content is published under.
    pub output_format: OutputFormat,
    /// Together with `inline_display`, controls whether registering a plot
    /// immediately produces display data.
    pub auto_display: bool,
    /// Upper bound on retained figures; the oldest are evicted beyond it.
    pub max_plots: usize,
    /// Must be enabled (alongside `auto_display`) for inline display data.
    pub inline_display: bool,
    /// Width reported in the HTML output metadata.
    pub image_width: u32,
    /// Height reported in the HTML output metadata.
    pub image_height: u32,
}
47
/// A Jupyter `display_data` payload produced for a rendered figure.
///
/// Gated on the `jupyter` feature like the rest of this module: the
/// `HashMap` and `JsonValue` imports it relies on only exist when that
/// feature is enabled.
#[cfg(feature = "jupyter")]
#[derive(Debug, Clone)]
pub struct DisplayData {
    /// Rendered content keyed by MIME type.
    pub data: HashMap<String, JsonValue>,
    /// Per-MIME-type display hints (e.g. `isolated`, width/height).
    pub metadata: HashMap<String, JsonValue>,
    /// RunMat bookkeeping values (plot id, version) attached to the output.
    pub transient: HashMap<String, JsonValue>,
}
58
#[cfg(feature = "jupyter")]
impl Default for JupyterPlottingConfig {
    /// HTML output, displayed inline automatically, keeping at most 100
    /// figures, with an 800x600 image size hint.
    fn default() -> Self {
        let (image_width, image_height) = (800, 600);
        let display_enabled = true;

        JupyterPlottingConfig {
            output_format: OutputFormat::HTML,
            max_plots: 100,
            auto_display: display_enabled,
            inline_display: display_enabled,
            image_width,
            image_height,
        }
    }
}
72
#[cfg(feature = "jupyter")]
impl JupyterPlottingManager {
    /// Bin count used by `hist` when the caller does not supply one.
    const DEFAULT_HIST_BINS: usize = 20;

    /// Creates a manager using [`JupyterPlottingConfig::default`].
    pub fn new() -> Self {
        Self::with_config(JupyterPlottingConfig::default())
    }

    /// Creates a manager with an explicit configuration and a backend
    /// matching its output format.
    pub fn with_config(config: JupyterPlottingConfig) -> Self {
        Self {
            backend: Self::backend_for(&config.output_format),
            config,
            active_plots: HashMap::new(),
            plot_counter: 0,
        }
    }

    /// Builds a rendering backend for the given output format.
    fn backend_for(format: &OutputFormat) -> JupyterBackend {
        match format {
            OutputFormat::HTML => JupyterBackend::with_format(OutputFormat::HTML),
            OutputFormat::PNG => JupyterBackend::with_format(OutputFormat::PNG),
            OutputFormat::SVG => JupyterBackend::with_format(OutputFormat::SVG),
            OutputFormat::Base64 => JupyterBackend::with_format(OutputFormat::Base64),
            OutputFormat::PlotlyJSON => JupyterBackend::with_format(OutputFormat::PlotlyJSON),
        }
    }

    /// Stores `figure` under a fresh `plot_<n>` id, evicts the oldest plots
    /// beyond `max_plots`, and renders display data when both
    /// `auto_display` and `inline_display` are enabled.
    ///
    /// # Errors
    /// Propagates rendering failures from [`Self::create_display_data`]; the
    /// figure stays registered in that case.
    pub fn register_plot(&mut self, mut figure: Figure) -> Result<Option<DisplayData>> {
        self.plot_counter += 1;
        let plot_id = format!("plot_{}", self.plot_counter);

        self.active_plots.insert(plot_id, figure.clone());
        self.cleanup_old_plots();

        if self.config.auto_display && self.config.inline_display {
            self.create_display_data(&mut figure).map(Some)
        } else {
            Ok(None)
        }
    }

    /// Renders `figure` with the configured backend and wraps it as a
    /// Jupyter display payload.
    ///
    /// Output is published under `image/svg+xml` for SVG and `text/html` for
    /// every other format. HTML output also carries `isolated`/width/height
    /// metadata; SVG and Plotly output carry `isolated` only. The transient
    /// `runmat_plot_id` is derived from the current counter, i.e. the id of
    /// the most recently registered plot.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` if the backend fails to render.
    pub fn create_display_data(&mut self, figure: &mut Figure) -> Result<DisplayData> {
        // (label used in error messages, MIME type the content is published under)
        let (label, mime_type) = match self.config.output_format {
            OutputFormat::HTML => ("HTML", "text/html"),
            OutputFormat::PNG => ("PNG", "text/html"),
            OutputFormat::SVG => ("SVG", "image/svg+xml"),
            OutputFormat::Base64 => ("Base64", "text/html"),
            OutputFormat::PlotlyJSON => ("Plotly", "text/html"),
        };

        let content = self
            .backend
            .display_figure(figure)
            .map_err(|e| KernelError::Execution(format!("{label} generation failed: {e}")))?;

        let mut data = HashMap::new();
        data.insert(mime_type.to_string(), JsonValue::String(content));

        let format_metadata = match self.config.output_format {
            OutputFormat::HTML => {
                let mut meta = Self::isolated_metadata();
                meta.insert(
                    "width".to_string(),
                    JsonValue::Number(self.config.image_width.into()),
                );
                meta.insert(
                    "height".to_string(),
                    JsonValue::Number(self.config.image_height.into()),
                );
                Some(meta)
            }
            OutputFormat::SVG | OutputFormat::PlotlyJSON => Some(Self::isolated_metadata()),
            OutputFormat::PNG | OutputFormat::Base64 => None,
        };

        let mut metadata = HashMap::new();
        if let Some(meta) = format_metadata {
            metadata.insert(mime_type.to_string(), JsonValue::Object(meta));
        }

        let mut transient = HashMap::new();
        transient.insert(
            "runmat_plot_id".to_string(),
            JsonValue::String(format!("plot_{}", self.plot_counter)),
        );
        transient.insert(
            "runmat_version".to_string(),
            JsonValue::String("0.0.1".to_string()),
        );

        Ok(DisplayData {
            data,
            metadata,
            transient,
        })
    }

    /// Metadata object containing only `"isolated": true`.
    fn isolated_metadata() -> serde_json::Map<String, JsonValue> {
        let mut meta = serde_json::Map::new();
        meta.insert("isolated".to_string(), JsonValue::Bool(true));
        meta
    }

    /// Looks up a registered figure by id.
    pub fn get_plot(&self, plot_id: &str) -> Option<&Figure> {
        self.active_plots.get(plot_id)
    }

    /// Ids of all registered plots, oldest first.
    pub fn list_plots(&self) -> Vec<String> {
        let mut plot_ids: Vec<String> = self.active_plots.keys().cloned().collect();
        plot_ids.sort_by_key(|id| Self::plot_sequence(id));
        plot_ids
    }

    /// Forgets every registered plot and restarts id numbering.
    pub fn clear_plots(&mut self) {
        self.active_plots.clear();
        self.plot_counter = 0;
    }

    /// Replaces the configuration, rebuilds the backend for the new output
    /// format, and immediately applies a reduced `max_plots` limit.
    pub fn update_config(&mut self, config: JupyterPlottingConfig) {
        self.backend = Self::backend_for(&config.output_format);
        self.config = config;
        self.cleanup_old_plots();
    }

    /// Current configuration.
    pub fn config(&self) -> &JupyterPlottingConfig {
        &self.config
    }

    /// Numeric sequence of a `plot_<n>` id; `None` for ids not in that form.
    fn plot_sequence(plot_id: &str) -> Option<u64> {
        plot_id.strip_prefix("plot_")?.parse().ok()
    }

    /// Evicts the oldest plots until at most `max_plots` remain.
    fn cleanup_old_plots(&mut self) {
        let excess = self.active_plots.len().saturating_sub(self.config.max_plots);
        if excess == 0 {
            return;
        }

        let mut plot_ids: Vec<String> = self.active_plots.keys().cloned().collect();
        // Order by the numeric suffix: a plain string sort would rank
        // "plot_10" before "plot_9" and evict the newest plots first.
        plot_ids.sort_by_key(|id| Self::plot_sequence(id));

        for plot_id in plot_ids.into_iter().take(excess) {
            self.active_plots.remove(&plot_id);
        }
    }

    /// Builds a figure for a MATLAB-style plotting call and registers it.
    ///
    /// Supported functions: `plot(x, y)`, `scatter(x, y)`, `bar(y)` and
    /// `hist(data[, bins])`.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` for unknown functions, missing
    /// arguments, non-numeric data, mismatched x/y lengths, invalid bin
    /// counts, or failures creating or rendering the plot.
    pub fn handle_plot_function(
        &mut self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<Option<DisplayData>> {
        let mut figure = Figure::new();

        match function_name {
            "plot" => {
                let (x_data, y_data) = self.extract_xy_pair(function_name, args)?;
                let line_plot =
                    runmat_plot::plots::LinePlot::new(x_data, y_data).map_err(|e| {
                        KernelError::Execution(format!("Failed to create line plot: {e}"))
                    })?;
                figure.add_line_plot(line_plot);
            }
            "scatter" => {
                let (x_data, y_data) = self.extract_xy_pair(function_name, args)?;
                let scatter_plot = runmat_plot::plots::ScatterPlot::new(x_data, y_data)
                    .map_err(KernelError::Execution)?;
                figure.add_scatter_plot(scatter_plot);
            }
            "bar" => {
                Self::require_args(function_name, args, 1)?;
                let y_data = self.extract_numeric_array(&args[0])?;
                let x_labels: Vec<String> = (0..y_data.len()).map(|i| i.to_string()).collect();

                let bar_chart = runmat_plot::plots::BarChart::new(x_labels, y_data)
                    .map_err(KernelError::Execution)?;
                figure.add_bar_chart(bar_chart);
            }
            "hist" => {
                Self::require_args(function_name, args, 1)?;
                let data = self.extract_numeric_array(&args[0])?;
                let bins = match args.get(1) {
                    Some(value) => self.extract_bin_count(value)?,
                    None => Self::DEFAULT_HIST_BINS,
                };

                let histogram = runmat_plot::plots::Histogram::new(data, bins)
                    .map_err(KernelError::Execution)?;
                figure.add_histogram(histogram);
            }
            _ => {
                return Err(KernelError::Execution(format!(
                    "Unknown plot function: {function_name}"
                )));
            }
        }

        self.register_plot(figure)
    }

    /// Fails unless at least `min` arguments were supplied.
    fn require_args(function_name: &str, args: &[JsonValue], min: usize) -> Result<()> {
        if args.len() < min {
            Err(KernelError::Execution(format!(
                "{function_name} requires at least {min} argument(s), got {}",
                args.len()
            )))
        } else {
            Ok(())
        }
    }

    /// Extracts equally long x and y vectors from the first two arguments.
    fn extract_xy_pair(
        &self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<(Vec<f64>, Vec<f64>)> {
        Self::require_args(function_name, args, 2)?;
        let x_data = self.extract_numeric_array(&args[0])?;
        let y_data = self.extract_numeric_array(&args[1])?;

        if x_data.len() != y_data.len() {
            return Err(KernelError::Execution(
                "X and Y data must have the same length".to_string(),
            ));
        }
        Ok((x_data, y_data))
    }

    /// Parses a histogram bin count, rejecting values below one.
    ///
    /// Fractional counts are truncated toward zero, as the previous
    /// `as usize` conversion did.
    fn extract_bin_count(&self, value: &JsonValue) -> Result<usize> {
        let requested = self.extract_number(value)?;
        if !requested.is_finite() || requested < 1.0 {
            return Err(KernelError::Execution(format!(
                "Histogram bin count must be a positive integer, got {requested}"
            )));
        }
        Ok(requested as usize)
    }

    /// Converts a JSON number or array of numbers into a vector of `f64`.
    ///
    /// A scalar number becomes a one-element vector.
    fn extract_numeric_array(&self, value: &JsonValue) -> Result<Vec<f64>> {
        match value {
            // `as_f64` succeeds for every JSON number (integers included), so
            // a separate integer fallback is unnecessary.
            JsonValue::Array(items) => items
                .iter()
                .map(|item| {
                    item.as_f64().ok_or_else(|| {
                        KernelError::Execution("Array must contain only numbers".to_string())
                    })
                })
                .collect(),
            JsonValue::Number(num) => num
                .as_f64()
                .map(|val| vec![val])
                .ok_or_else(|| KernelError::Execution("Invalid number format".to_string())),
            _ => Err(KernelError::Execution(
                "Expected array or number".to_string(),
            )),
        }
    }

    /// Converts a JSON number into an `f64`.
    fn extract_number(&self, value: &JsonValue) -> Result<f64> {
        match value {
            JsonValue::Number(num) => num
                .as_f64()
                .ok_or_else(|| KernelError::Execution("Invalid number format".to_string())),
            _ => Err(KernelError::Execution("Expected number".to_string())),
        }
    }
}
395
/// Equivalent to [`JupyterPlottingManager::new`].
///
/// Gated on the `jupyter` feature because `JupyterPlottingManager` itself
/// only exists when that feature is enabled.
#[cfg(feature = "jupyter")]
impl Default for JupyterPlottingManager {
    fn default() -> Self {
        Self::new()
    }
}
401
/// Hook for kernel components that route plotting calls through a
/// [`JupyterPlottingManager`].
///
/// Gated on the `jupyter` feature because every type in its signatures is.
#[cfg(feature = "jupyter")]
pub trait JupyterPlottingExtension {
    /// Handles a plotting function call with JSON-encoded arguments,
    /// returning display data when the call produces inline output.
    fn handle_jupyter_plot(
        &mut self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<Option<DisplayData>>;

    /// Mutable access to the underlying plotting manager.
    fn plotting_manager(&mut self) -> &mut JupyterPlottingManager;
}
414
// The items under test only exist with the `jupyter` feature, so the tests
// must be gated on it too or `cargo test` without the feature fails to build.
#[cfg(all(test, feature = "jupyter"))]
mod tests {
    use super::*;

    #[test]
    fn test_jupyter_plotting_manager_creation() {
        let manager = JupyterPlottingManager::new();
        assert_eq!(manager.config.output_format, OutputFormat::HTML);
        assert!(manager.config.auto_display);
        assert_eq!(manager.active_plots.len(), 0);
    }

    #[test]
    fn test_config_update() {
        let mut manager = JupyterPlottingManager::new();

        let new_config = JupyterPlottingConfig {
            output_format: OutputFormat::SVG,
            auto_display: false,
            max_plots: 50,
            inline_display: false,
            image_width: 1024,
            image_height: 768,
        };

        manager.update_config(new_config.clone());
        assert_eq!(manager.config.output_format, OutputFormat::SVG);
        assert!(!manager.config.auto_display);
        assert_eq!(manager.config.max_plots, 50);
    }

    #[test]
    fn test_plot_management() {
        let mut manager = JupyterPlottingManager::new();
        let figure = Figure::new().with_title("Test Plot");

        let display_data = manager.register_plot(figure).unwrap();
        assert!(display_data.is_some());
        assert_eq!(manager.active_plots.len(), 1);
        assert_eq!(manager.list_plots().len(), 1);

        manager.clear_plots();
        assert_eq!(manager.active_plots.len(), 0);
        assert_eq!(manager.plot_counter, 0);
    }

    #[test]
    fn test_extract_numeric_array() {
        let manager = JupyterPlottingManager::new();

        let json_array = JsonValue::Array(vec![
            JsonValue::Number(serde_json::Number::from(1)),
            JsonValue::Number(serde_json::Number::from(2)),
            JsonValue::Number(serde_json::Number::from(3)),
        ]);

        let result = manager.extract_numeric_array(&json_array).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_extract_numeric_array_scalar_and_invalid() {
        let manager = JupyterPlottingManager::new();

        let scalar = serde_json::json!(4.5);
        assert_eq!(manager.extract_numeric_array(&scalar).unwrap(), vec![4.5]);

        let mixed = serde_json::json!([1, "two", 3]);
        assert!(manager.extract_numeric_array(&mixed).is_err());

        let text = serde_json::json!("not numeric");
        assert!(manager.extract_numeric_array(&text).is_err());
    }

    #[test]
    fn test_extract_number() {
        let manager = JupyterPlottingManager::new();
        assert_eq!(manager.extract_number(&serde_json::json!(7)).unwrap(), 7.0);
        assert!(manager.extract_number(&serde_json::json!([7])).is_err());
    }

    #[test]
    fn test_plot_function_handling() {
        let mut manager = JupyterPlottingManager::new();

        let x_data = JsonValue::Array(vec![
            JsonValue::Number(serde_json::Number::from(1)),
            JsonValue::Number(serde_json::Number::from(2)),
            JsonValue::Number(serde_json::Number::from(3)),
        ]);

        let y_data = JsonValue::Array(vec![
            JsonValue::Number(serde_json::Number::from(2)),
            JsonValue::Number(serde_json::Number::from(4)),
            JsonValue::Number(serde_json::Number::from(6)),
        ]);

        let result = manager
            .handle_plot_function("plot", &[x_data, y_data])
            .unwrap();
        assert!(result.is_some());
        assert_eq!(manager.active_plots.len(), 1);
    }

    #[test]
    fn test_unknown_plot_function_is_rejected() {
        let mut manager = JupyterPlottingManager::new();
        assert!(manager.handle_plot_function("not_a_plot", &[]).is_err());
        assert!(manager.active_plots.is_empty());
    }

    #[test]
    fn test_mismatched_xy_lengths_are_rejected() {
        let mut manager = JupyterPlottingManager::new();
        let args = [serde_json::json!([1, 2, 3]), serde_json::json!([1, 2])];
        assert!(manager.handle_plot_function("plot", &args).is_err());
        assert!(manager.active_plots.is_empty());
    }
}