1#[cfg(feature = "jupyter")]
7use crate::{KernelError, Result};
8#[cfg(feature = "jupyter")]
9use runmat_plot::jupyter::{JupyterBackend, OutputFormat};
10#[cfg(feature = "jupyter")]
11use runmat_plot::plots::Figure;
12#[cfg(feature = "jupyter")]
13use serde_json::Value as JsonValue;
14#[cfg(feature = "jupyter")]
15use std::collections::HashMap;
16
/// Renders RunMat figures for Jupyter and keeps a bounded history of registered figures.
#[cfg(feature = "jupyter")]
pub struct JupyterPlottingManager {
    /// Renderer constructed for `config.output_format`.
    backend: JupyterBackend,
    /// Output format, display flags, retention limit and reported image size.
    config: JupyterPlottingConfig,
    /// Registered figures keyed by `plot_<n>` identifiers.
    active_plots: HashMap<String, Figure>,
    /// Source of the `<n>` in plot identifiers; incremented per registration, reset to 0 when
    /// plots are cleared.
    plot_counter: u64,
}
29
/// Settings controlling how `JupyterPlottingManager` renders and retains figures.
#[cfg(feature = "jupyter")]
#[derive(Debug, Clone)]
pub struct JupyterPlottingConfig {
    /// Format the backend renders figures into.
    pub output_format: OutputFormat,
    /// Render each figure as soon as it is registered (only takes effect together with
    /// `inline_display`).
    pub auto_display: bool,
    /// Maximum number of figures kept; the oldest are evicted once this is exceeded.
    pub max_plots: usize,
    /// Allow rendering on registration (only takes effect together with `auto_display`).
    pub inline_display: bool,
    /// Width reported in the `text/html` display metadata for HTML output.
    pub image_width: u32,
    /// Height reported in the `text/html` display metadata for HTML output.
    pub image_height: u32,
}
47
/// A Jupyter MIME bundle produced for one rendered figure.
///
/// Gated on the `jupyter` feature like every other item in this file: its fields use the
/// feature-gated `HashMap` / `JsonValue` imports and cannot compile without them.
#[cfg(feature = "jupyter")]
#[derive(Debug, Clone)]
pub struct DisplayData {
    /// MIME type (e.g. `text/html`, `image/svg+xml`) mapped to the rendered content.
    pub data: HashMap<String, JsonValue>,
    /// Per-MIME-type metadata, keyed the same way as `data`.
    pub metadata: HashMap<String, JsonValue>,
    /// Transient fields; `JupyterPlottingManager` fills in `runmat_plot_id` and
    /// `runmat_version`.
    pub transient: HashMap<String, JsonValue>,
}
58
#[cfg(feature = "jupyter")]
impl Default for JupyterPlottingConfig {
    /// Inline HTML output at 800x600, displayed automatically, retaining up to 100 figures.
    fn default() -> Self {
        JupyterPlottingConfig {
            output_format: OutputFormat::HTML,
            inline_display: true,
            auto_display: true,
            image_width: 800,
            image_height: 600,
            max_plots: 100,
        }
    }
}
72
#[cfg(feature = "jupyter")]
impl JupyterPlottingManager {
    /// Bin count used by `hist` when the caller does not supply one.
    const DEFAULT_HISTOGRAM_BINS: usize = 20;

    /// Upper bound on a caller-supplied `hist` bin count. Without it, a request such as `1e12`
    /// (or `Inf`, which `as usize` saturates to `usize::MAX`) attempts an enormous allocation
    /// and aborts the kernel instead of reporting an error.
    const MAX_HISTOGRAM_BINS: usize = 10_000;

    /// Creates a manager using `JupyterPlottingConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(JupyterPlottingConfig::default())
    }

    /// Creates a manager with an explicit configuration; the rendering backend is built for
    /// `config.output_format`.
    pub fn with_config(config: JupyterPlottingConfig) -> Self {
        Self {
            backend: Self::backend_for(&config.output_format),
            config,
            active_plots: HashMap::new(),
            plot_counter: 0,
        }
    }

    /// Registers `figure` under a fresh `plot_<n>` identifier and, when both `auto_display`
    /// and `inline_display` are enabled, renders it immediately.
    ///
    /// If more than `max_plots` figures are stored afterwards, the oldest are evicted.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` if rendering fails.
    pub fn register_plot(&mut self, mut figure: Figure) -> Result<Option<DisplayData>> {
        self.plot_counter += 1;
        let plot_id = format!("plot_{}", self.plot_counter);

        // Keep a copy so the figure can be retrieved later through `get_plot`.
        self.active_plots.insert(plot_id, figure.clone());

        if self.active_plots.len() > self.config.max_plots {
            self.cleanup_old_plots();
        }

        if self.config.auto_display && self.config.inline_display {
            self.create_display_data(&mut figure).map(Some)
        } else {
            Ok(None)
        }
    }

    /// Renders `figure` with the configured backend and packages the result as a MIME bundle.
    ///
    /// SVG output is published under `image/svg+xml`; every other format under `text/html`.
    /// The transient section carries the most recently minted plot identifier and a version
    /// tag.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` if the backend fails to render the figure.
    pub fn create_display_data(&mut self, figure: &mut Figure) -> Result<DisplayData> {
        // (MIME key, label used in error messages, metadata stored under the MIME key).
        // NOTE(review): PNG and Base64 content is published as `text/html`; this presumably
        // relies on the backend wrapping the image in HTML — confirm against JupyterBackend.
        let (mime_type, label, mime_metadata) = match self.config.output_format {
            OutputFormat::HTML => (
                "text/html",
                "HTML",
                Some(Self::isolated_metadata(Some((
                    self.config.image_width,
                    self.config.image_height,
                )))),
            ),
            OutputFormat::PNG => ("text/html", "PNG", None),
            OutputFormat::SVG => ("image/svg+xml", "SVG", Some(Self::isolated_metadata(None))),
            OutputFormat::Base64 => ("text/html", "Base64", None),
            OutputFormat::PlotlyJSON => {
                ("text/html", "Plotly", Some(Self::isolated_metadata(None)))
            }
        };

        let content = self.backend.display_figure(figure).map_err(|e| {
            KernelError::Execution(format!("{label} generation failed: {e}"))
        })?;

        let mut data = HashMap::new();
        data.insert(mime_type.to_string(), JsonValue::String(content));

        let mut metadata = HashMap::new();
        if let Some(meta) = mime_metadata {
            metadata.insert(mime_type.to_string(), meta);
        }

        let mut transient = HashMap::new();
        transient.insert(
            "runmat_plot_id".to_string(),
            JsonValue::String(format!("plot_{}", self.plot_counter)),
        );
        transient.insert(
            "runmat_version".to_string(),
            JsonValue::String("0.0.1".to_string()),
        );

        Ok(DisplayData {
            data,
            metadata,
            transient,
        })
    }

    /// Builds `{"isolated": true}` display metadata, optionally with `width`/`height`.
    fn isolated_metadata(dimensions: Option<(u32, u32)>) -> JsonValue {
        let mut meta = serde_json::Map::new();
        meta.insert("isolated".to_string(), JsonValue::Bool(true));
        if let Some((width, height)) = dimensions {
            meta.insert("width".to_string(), JsonValue::Number(width.into()));
            meta.insert("height".to_string(), JsonValue::Number(height.into()));
        }
        JsonValue::Object(meta)
    }

    /// Returns the stored figure for `plot_id`, if it has not been evicted or cleared.
    pub fn get_plot(&self, plot_id: &str) -> Option<&Figure> {
        self.active_plots.get(plot_id)
    }

    /// Lists the identifiers of all stored figures, oldest first.
    pub fn list_plots(&self) -> Vec<String> {
        let mut ids: Vec<String> = self.active_plots.keys().cloned().collect();
        ids.sort_by_key(|id| Self::plot_sequence(id));
        ids
    }

    /// Drops every stored figure and restarts identifier numbering.
    pub fn clear_plots(&mut self) {
        self.active_plots.clear();
        self.plot_counter = 0;
    }

    /// Replaces the configuration and rebuilds the backend for the new output format.
    ///
    /// Stored figures are not trimmed here; a lowered `max_plots` takes effect on the next
    /// registration.
    pub fn update_config(&mut self, config: JupyterPlottingConfig) {
        self.backend = Self::backend_for(&config.output_format);
        self.config = config;
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &JupyterPlottingConfig {
        &self.config
    }

    /// Constructs the rendering backend for `format` (shared by `with_config` and
    /// `update_config`).
    fn backend_for(format: &OutputFormat) -> JupyterBackend {
        match format {
            OutputFormat::HTML => JupyterBackend::with_format(OutputFormat::HTML),
            OutputFormat::PNG => JupyterBackend::with_format(OutputFormat::PNG),
            OutputFormat::SVG => JupyterBackend::with_format(OutputFormat::SVG),
            OutputFormat::Base64 => JupyterBackend::with_format(OutputFormat::Base64),
            OutputFormat::PlotlyJSON => JupyterBackend::with_format(OutputFormat::PlotlyJSON),
        }
    }

    /// Evicts the oldest figures until at most `max_plots` remain.
    ///
    /// Age comes from the numeric suffix of the `plot_<n>` identifier. Sorting the raw
    /// strings is wrong once identifiers reach two digits ("plot_10" < "plot_2"), which
    /// would evict recent plots while older ones survive.
    fn cleanup_old_plots(&mut self) {
        let excess = self
            .active_plots
            .len()
            .saturating_sub(self.config.max_plots);
        if excess == 0 {
            return;
        }

        let mut ids: Vec<String> = self.active_plots.keys().cloned().collect();
        ids.sort_by_key(|id| Self::plot_sequence(id));
        for id in ids.into_iter().take(excess) {
            self.active_plots.remove(&id);
        }
    }

    /// Extracts `n` from a `plot_<n>` identifier; identifiers of any other shape map to 0.
    fn plot_sequence(plot_id: &str) -> u64 {
        plot_id
            .strip_prefix("plot_")
            .and_then(|n| n.parse::<u64>().ok())
            .unwrap_or(0)
    }

    /// Builds a figure for the plotting function `function_name` and registers it.
    ///
    /// Supported functions:
    /// - `plot(x, y)` / `scatter(x, y)`: numeric arrays of equal length;
    /// - `bar(y)`: bars labelled by zero-based index;
    /// - `hist(data[, bins])`: histogram with `bins` buckets (default 20).
    ///
    /// # Errors
    /// Returns `KernelError::Execution` for unknown function names, missing or non-numeric
    /// arguments, mismatched X/Y lengths, an invalid bin count, or plot construction and
    /// rendering failures.
    pub fn handle_plot_function(
        &mut self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<Option<DisplayData>> {
        let mut figure = Figure::new();

        match function_name {
            "plot" => {
                let (x_data, y_data) = self.extract_xy_data(function_name, args)?;
                let line_plot =
                    runmat_plot::plots::LinePlot::new(x_data, y_data).map_err(|e| {
                        KernelError::Execution(format!("Failed to create line plot: {e}"))
                    })?;
                figure.add_line_plot(line_plot);
            }
            "scatter" => {
                let (x_data, y_data) = self.extract_xy_data(function_name, args)?;
                let scatter_plot = runmat_plot::plots::ScatterPlot::new(x_data, y_data)
                    .map_err(KernelError::Execution)?;
                figure.add_scatter_plot(scatter_plot);
            }
            "bar" => {
                let y_data =
                    self.extract_numeric_array(Self::required_data_arg(function_name, args)?)?;
                let x_labels: Vec<String> = (0..y_data.len()).map(|i| i.to_string()).collect();

                let bar_chart = runmat_plot::plots::BarChart::new(x_labels, y_data)
                    .map_err(KernelError::Execution)?;
                figure.add_bar_chart(bar_chart);
            }
            "hist" => {
                let data =
                    self.extract_numeric_array(Self::required_data_arg(function_name, args)?)?;
                let bins = match args.get(1) {
                    Some(arg) => Self::histogram_bin_count(self.extract_number(arg)?)?,
                    None => Self::DEFAULT_HISTOGRAM_BINS,
                };

                let (labels, counts) = self.build_histogram_series(&data, bins)?;
                let histogram = runmat_plot::plots::BarChart::new(labels, counts)
                    .map_err(KernelError::Execution)?;
                figure.add_bar_chart(histogram);
            }
            _ => {
                return Err(KernelError::Execution(format!(
                    "Unknown plot function: {function_name}"
                )));
            }
        }

        self.register_plot(figure)
    }

    /// Returns the first argument, or an error naming `function_name` if there is none.
    fn required_data_arg<'a>(function_name: &str, args: &'a [JsonValue]) -> Result<&'a JsonValue> {
        args.first().ok_or_else(|| {
            KernelError::Execution(format!("{function_name} requires a data argument"))
        })
    }

    /// Extracts equal-length X and Y arrays from the first two arguments.
    fn extract_xy_data(
        &self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<(Vec<f64>, Vec<f64>)> {
        let (x_arg, y_arg) = match args {
            [x, y, ..] => (x, y),
            _ => {
                return Err(KernelError::Execution(format!(
                    "{function_name} requires X and Y data arguments"
                )));
            }
        };

        let x_data = self.extract_numeric_array(x_arg)?;
        let y_data = self.extract_numeric_array(y_arg)?;
        if x_data.len() != y_data.len() {
            return Err(KernelError::Execution(
                "X and Y data must have the same length".to_string(),
            ));
        }
        Ok((x_data, y_data))
    }

    /// Converts a requested `hist` bin count to `usize`, rejecting non-finite or oversized
    /// values.
    ///
    /// Fractional values truncate and non-positive values saturate to 0. As before,
    /// `build_histogram_series` then clamps 0 up to a single bin.
    fn histogram_bin_count(requested: f64) -> Result<usize> {
        if !requested.is_finite() || requested > Self::MAX_HISTOGRAM_BINS as f64 {
            return Err(KernelError::Execution(format!(
                "Histogram bin count must be a finite number not exceeding {}",
                Self::MAX_HISTOGRAM_BINS
            )));
        }
        Ok(requested as usize)
    }

    /// Converts a JSON number, or an array of JSON numbers, into a vector of `f64`.
    fn extract_numeric_array(&self, value: &JsonValue) -> Result<Vec<f64>> {
        match value {
            JsonValue::Array(items) => items
                .iter()
                .map(|item| {
                    item.as_f64().ok_or_else(|| {
                        KernelError::Execution("Array must contain only numbers".to_string())
                    })
                })
                .collect(),
            JsonValue::Number(num) => num
                .as_f64()
                .map(|val| vec![val])
                .ok_or_else(|| KernelError::Execution("Invalid number format".to_string())),
            _ => Err(KernelError::Execution(
                "Expected array or number".to_string(),
            )),
        }
    }

    /// Buckets the finite values of `data` into `bins` equal-width buckets spanning
    /// [min, max].
    ///
    /// Returns `"start-end"` labels (3 decimals) and per-bucket counts. Non-finite samples are
    /// ignored. `bins` of 0 is treated as 1.
    ///
    /// # Errors
    /// Fails if `data` is empty or contains no finite values.
    fn build_histogram_series(&self, data: &[f64], bins: usize) -> Result<(Vec<String>, Vec<f64>)> {
        if data.is_empty() {
            return Err(KernelError::Execution(
                "Histogram requires at least one data point".to_string(),
            ));
        }
        let bins = bins.max(1);

        let mut min_val = f64::INFINITY;
        let mut max_val = f64::NEG_INFINITY;
        for &value in data.iter().filter(|v| v.is_finite()) {
            if value < min_val {
                min_val = value;
            }
            if value > max_val {
                max_val = value;
            }
        }
        if !min_val.is_finite() || !max_val.is_finite() {
            return Err(KernelError::Execution(
                "Histogram data must be finite".to_string(),
            ));
        }

        // Avoid a zero-width range (and division by zero) when all samples are identical.
        let span = (max_val - min_val).max(1e-9);
        let bucket_width = span / bins as f64;

        let mut counts = vec![0f64; bins];
        for &value in data.iter().filter(|v| v.is_finite()) {
            // `as usize` saturates negatives to 0. The maximum sample lands exactly on `bins`,
            // so clamp it into the last bucket.
            let idx = (((value - min_val) / bucket_width).floor() as usize).min(bins - 1);
            counts[idx] += 1.0;
        }

        let labels = (0..bins)
            .map(|i| {
                let start = min_val + bucket_width * i as f64;
                let end = start + bucket_width;
                format!("{start:.3}-{end:.3}")
            })
            .collect();
        Ok((labels, counts))
    }

    /// Reads a single JSON number as `f64`.
    fn extract_number(&self, value: &JsonValue) -> Result<f64> {
        match value {
            JsonValue::Number(num) => num
                .as_f64()
                .ok_or_else(|| KernelError::Execution("Invalid number format".to_string())),
            _ => Err(KernelError::Execution("Expected number".to_string())),
        }
    }
}
447
/// Gated like `JupyterPlottingManager` itself, which only exists with the `jupyter` feature.
#[cfg(feature = "jupyter")]
impl Default for JupyterPlottingManager {
    /// Equivalent to `JupyterPlottingManager::new()`.
    fn default() -> Self {
        Self::new()
    }
}
453
/// Hook for components that route plotting calls to a `JupyterPlottingManager`.
///
/// Gated on the `jupyter` feature because its signatures use the feature-gated `JsonValue`
/// and `JupyterPlottingManager` types.
#[cfg(feature = "jupyter")]
pub trait JupyterPlottingExtension {
    /// Handles plotting function `function_name` called with JSON-encoded `args`.
    ///
    /// NOTE(review): presumably mirrors `JupyterPlottingManager::handle_plot_function`
    /// (returning display data when the figure should be shown) — confirm with implementors.
    fn handle_jupyter_plot(
        &mut self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<Option<DisplayData>>;

    /// Gives mutable access to the plotting manager backing this implementor.
    fn plotting_manager(&mut self) -> &mut JupyterPlottingManager;
}
466
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a JSON array holding the given integers.
    fn int_array(values: &[i64]) -> JsonValue {
        JsonValue::Array(
            values
                .iter()
                .map(|&v| JsonValue::Number(serde_json::Number::from(v)))
                .collect(),
        )
    }

    #[test]
    fn test_jupyter_plotting_manager_creation() {
        let fresh = JupyterPlottingManager::new();
        assert_eq!(fresh.config.output_format, OutputFormat::HTML);
        assert!(fresh.config.auto_display);
        assert!(fresh.active_plots.is_empty());
    }

    #[test]
    fn test_config_update() {
        let mut mgr = JupyterPlottingManager::new();

        mgr.update_config(JupyterPlottingConfig {
            output_format: OutputFormat::SVG,
            auto_display: false,
            max_plots: 50,
            inline_display: false,
            image_width: 1024,
            image_height: 768,
        });

        let applied = mgr.config();
        assert_eq!(applied.output_format, OutputFormat::SVG);
        assert!(!applied.auto_display);
        assert_eq!(applied.max_plots, 50);
    }

    #[test]
    fn test_plot_management() {
        let mut mgr = JupyterPlottingManager::new();

        let shown = mgr
            .register_plot(Figure::new().with_title("Test Plot"))
            .unwrap();
        assert!(shown.is_some());
        assert_eq!(mgr.active_plots.len(), 1);
        assert_eq!(mgr.list_plots().len(), 1);

        mgr.clear_plots();
        assert!(mgr.active_plots.is_empty());
        assert_eq!(mgr.plot_counter, 0);
    }

    #[test]
    fn test_extract_numeric_array() {
        let mgr = JupyterPlottingManager::new();
        let values = mgr.extract_numeric_array(&int_array(&[1, 2, 3])).unwrap();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_plot_function_handling() {
        let mut mgr = JupyterPlottingManager::new();
        let xs = int_array(&[1, 2, 3]);
        let ys = int_array(&[2, 4, 6]);

        let outcome = mgr.handle_plot_function("plot", &[xs, ys]).unwrap();
        assert!(outcome.is_some());
        assert_eq!(mgr.active_plots.len(), 1);
    }
}
554}