#[cfg(feature = "jupyter")]
use crate::{KernelError, Result};
#[cfg(feature = "jupyter")]
use runmat_plot::jupyter::{JupyterBackend, OutputFormat};
#[cfg(feature = "jupyter")]
use runmat_plot::plots::Figure;
#[cfg(feature = "jupyter")]
use serde_json::Value as JsonValue;
#[cfg(feature = "jupyter")]
use std::collections::HashMap;
/// Tracks figures produced in a Jupyter session and turns them into display payloads.
///
/// Figures are stored under generated ids of the form `plot_N`. Rendering is delegated
/// to a [`JupyterBackend`] configured from [`JupyterPlottingConfig::output_format`].
#[cfg(feature = "jupyter")]
pub struct JupyterPlottingManager {
    /// Renderer matching `config.output_format`; rebuilt whenever the config changes.
    backend: JupyterBackend,
    /// Active display settings.
    config: JupyterPlottingConfig,
    /// Registered figures keyed by their `plot_N` id.
    active_plots: HashMap<String, Figure>,
    /// Sequence number of the most recently registered plot (0 when none / after clear).
    plot_counter: u64,
}
/// User-tunable settings for [`JupyterPlottingManager`].
#[derive(Clone, Debug)]
#[cfg(feature = "jupyter")]
pub struct JupyterPlottingConfig {
    /// Format the backend renders to; also picks the MIME key used for display data.
    pub output_format: OutputFormat,
    /// Render newly registered plots immediately (only when `inline_display` is also set).
    pub auto_display: bool,
    /// Maximum number of figures retained; the oldest are evicted beyond this.
    pub max_plots: usize,
    /// Allow inline rendering on registration (combined with `auto_display`).
    pub inline_display: bool,
    /// Width reported in the `width` metadata entry for HTML output.
    pub image_width: u32,
    /// Height reported in the `height` metadata entry for HTML output.
    pub image_height: u32,
}
/// Rich-output bundle built by `JupyterPlottingManager::create_display_data`.
///
/// Every map is keyed by a string: `data` and `metadata` by MIME type (e.g.
/// `"text/html"`, `"image/svg+xml"`), `transient` by RunMat-specific field names.
/// NOTE(review): the layout mirrors the content of a Jupyter `display_data` message and
/// is presumably forwarded as one — confirm at the call site.
///
/// Gated on the `jupyter` feature because `JsonValue` and `HashMap` are only imported
/// when that feature is enabled.
#[cfg(feature = "jupyter")]
#[derive(Debug, Clone)]
pub struct DisplayData {
    /// Rendered output keyed by MIME type.
    pub data: HashMap<String, JsonValue>,
    /// Per-MIME-type rendering hints (e.g. `isolated`, `width`, `height`).
    pub metadata: HashMap<String, JsonValue>,
    /// Extra identifiers such as `runmat_plot_id` and `runmat_version`.
    pub transient: HashMap<String, JsonValue>,
}
#[cfg(feature = "jupyter")]
impl Default for JupyterPlottingConfig {
    /// Inline HTML output at 800x600, shown automatically, keeping up to 100 plots.
    fn default() -> Self {
        let output_format = OutputFormat::HTML;
        let (image_width, image_height) = (800, 600);
        Self {
            image_width,
            image_height,
            max_plots: 100,
            inline_display: true,
            auto_display: true,
            output_format,
        }
    }
}
#[cfg(feature = "jupyter")]
impl JupyterPlottingManager {
    /// Bin count used by `hist` when the caller does not supply one.
    const DEFAULT_HISTOGRAM_BINS: usize = 20;
    /// Upper bound on a caller-requested histogram bin count. A saturating `f64 as usize`
    /// cast of e.g. `1e300` would otherwise request `usize::MAX` buckets and panic
    /// while allocating.
    const MAX_HISTOGRAM_BINS: usize = 10_000;

    /// Creates a manager using [`JupyterPlottingConfig::default`].
    pub fn new() -> Self {
        Self::with_config(JupyterPlottingConfig::default())
    }

    /// Creates a manager whose backend renders in `config.output_format`.
    pub fn with_config(config: JupyterPlottingConfig) -> Self {
        let backend = Self::backend_for(&config.output_format);
        Self {
            backend,
            config,
            active_plots: HashMap::new(),
            plot_counter: 0,
        }
    }

    /// Builds a backend for `format`.
    ///
    /// Shared by `with_config` and `update_config` so both construct the backend the
    /// same way. Matching on the variant avoids requiring `OutputFormat: Copy`.
    fn backend_for(format: &OutputFormat) -> JupyterBackend {
        match format {
            OutputFormat::HTML => JupyterBackend::with_format(OutputFormat::HTML),
            OutputFormat::PNG => JupyterBackend::with_format(OutputFormat::PNG),
            OutputFormat::SVG => JupyterBackend::with_format(OutputFormat::SVG),
            OutputFormat::Base64 => JupyterBackend::with_format(OutputFormat::Base64),
            OutputFormat::PlotlyJSON => JupyterBackend::with_format(OutputFormat::PlotlyJSON),
        }
    }

    /// Stores a copy of `figure` under a fresh `plot_N` id and optionally renders it.
    ///
    /// If more than `max_plots` figures are stored afterwards, the oldest are evicted.
    /// Rendering happens only when both `auto_display` and `inline_display` are set;
    /// otherwise `Ok(None)` is returned.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` if rendering fails. The figure stays registered
    /// in that case.
    pub fn register_plot(&mut self, mut figure: Figure) -> Result<Option<DisplayData>> {
        self.plot_counter += 1;
        let plot_id = format!("plot_{}", self.plot_counter);
        // Store a snapshot: rendering below takes `&mut Figure`.
        self.active_plots.insert(plot_id, figure.clone());
        if self.active_plots.len() > self.config.max_plots {
            self.cleanup_old_plots();
        }
        if self.config.auto_display && self.config.inline_display {
            self.create_display_data(&mut figure).map(Some)
        } else {
            Ok(None)
        }
    }

    /// Renders `figure` with the configured backend and wraps the result in [`DisplayData`].
    ///
    /// The MIME key and metadata depend on the output format:
    /// - HTML: `text/html`, with metadata `isolated`, `width`, `height`.
    /// - PNG: `text/html`, no metadata.
    /// - SVG: `image/svg+xml`, with metadata `isolated`.
    /// - Base64: `text/html`, no metadata.
    /// - PlotlyJSON: `text/html`, with metadata `isolated`.
    ///
    /// `transient.runmat_plot_id` names the most recently registered plot id.
    ///
    /// # Errors
    /// Returns `KernelError::Execution("<format> generation failed: …")` when the
    /// backend fails.
    pub fn create_display_data(&mut self, figure: &mut Figure) -> Result<DisplayData> {
        // Each format maps to: MIME key, label for error messages, whether the
        // output is marked `isolated`, and whether width/height hints are attached.
        let (mime_type, label, isolated, sized) = match self.config.output_format {
            OutputFormat::HTML => ("text/html", "HTML", true, true),
            // NOTE(review): PNG/Base64 output is published as `text/html`, which assumes
            // the backend returns an HTML wrapper (e.g. an <img> tag) — confirm.
            OutputFormat::PNG => ("text/html", "PNG", false, false),
            OutputFormat::SVG => ("image/svg+xml", "SVG", true, false),
            OutputFormat::Base64 => ("text/html", "Base64", false, false),
            OutputFormat::PlotlyJSON => ("text/html", "Plotly", true, false),
        };

        let content = self
            .backend
            .display_figure(figure)
            .map_err(|e| KernelError::Execution(format!("{label} generation failed: {e}")))?;

        let mut data = HashMap::new();
        data.insert(mime_type.to_string(), JsonValue::String(content));

        let mut metadata = HashMap::new();
        if isolated {
            let mut meta = serde_json::Map::new();
            meta.insert("isolated".to_string(), JsonValue::Bool(true));
            if sized {
                meta.insert(
                    "width".to_string(),
                    JsonValue::Number(self.config.image_width.into()),
                );
                meta.insert(
                    "height".to_string(),
                    JsonValue::Number(self.config.image_height.into()),
                );
            }
            metadata.insert(mime_type.to_string(), JsonValue::Object(meta));
        }

        let mut transient = HashMap::new();
        transient.insert(
            "runmat_plot_id".to_string(),
            JsonValue::String(format!("plot_{}", self.plot_counter)),
        );
        transient.insert(
            "runmat_version".to_string(),
            JsonValue::String("0.0.1".to_string()),
        );

        Ok(DisplayData {
            data,
            metadata,
            transient,
        })
    }

    /// Looks up a registered figure by its `plot_N` id.
    pub fn get_plot(&self, plot_id: &str) -> Option<&Figure> {
        self.active_plots.get(plot_id)
    }

    /// Returns the ids of all registered plots, oldest first.
    pub fn list_plots(&self) -> Vec<String> {
        let mut ids: Vec<String> = self.active_plots.keys().cloned().collect();
        ids.sort_by_key(|id| Self::plot_sequence(id));
        ids
    }

    /// Drops every registered plot and restarts id numbering at `plot_1`.
    pub fn clear_plots(&mut self) {
        self.active_plots.clear();
        self.plot_counter = 0;
    }

    /// Replaces the configuration and rebuilds the backend for its output format.
    pub fn update_config(&mut self, config: JupyterPlottingConfig) {
        self.config = config;
        self.backend = Self::backend_for(&self.config.output_format);
    }

    /// Current configuration.
    pub fn config(&self) -> &JupyterPlottingConfig {
        &self.config
    }

    /// Evicts the oldest plots until at most `max_plots` remain.
    ///
    /// Age is taken from the numeric `N` in `plot_N`. Ids are not compared as strings,
    /// because lexicographic order puts `plot_10` before `plot_9`. That ordering would
    /// evict recent plots, including the one just registered.
    fn cleanup_old_plots(&mut self) {
        let excess = self.active_plots.len().saturating_sub(self.config.max_plots);
        if excess == 0 {
            return;
        }
        let mut by_age: Vec<(u64, String)> = self
            .active_plots
            .keys()
            .map(|id| (Self::plot_sequence(id), id.clone()))
            .collect();
        by_age.sort_unstable();
        for (_, id) in by_age.into_iter().take(excess) {
            self.active_plots.remove(&id);
        }
    }

    /// Extracts `N` from an id of the form `plot_N`. Ids that do not match sort as
    /// oldest (0).
    fn plot_sequence(plot_id: &str) -> u64 {
        plot_id
            .strip_prefix("plot_")
            .and_then(|n| n.parse().ok())
            .unwrap_or(0)
    }

    /// Builds a figure for a plotting built-in and registers it.
    ///
    /// Supported calls:
    /// - `plot(x, y)` and `scatter(x, y)`: `x` and `y` must have equal length.
    /// - `bar(y)`: uses index labels `0..len`.
    /// - `hist(data[, bins])`: `bins` defaults to 20 and must be in `1..=10000`.
    ///
    /// Calls with too few arguments register an empty figure.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` for unknown function names, non-numeric or
    /// mismatched data, an invalid bin count, or plot/render failures.
    pub fn handle_plot_function(
        &mut self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<Option<DisplayData>> {
        let mut figure = Figure::new();
        match function_name {
            "plot" => {
                if let [x_arg, y_arg, ..] = args {
                    let (x_data, y_data) = self.extract_xy_pair(x_arg, y_arg)?;
                    let line_plot =
                        runmat_plot::plots::LinePlot::new(x_data, y_data).map_err(|e| {
                            KernelError::Execution(format!("Failed to create line plot: {e}"))
                        })?;
                    figure.add_line_plot(line_plot);
                }
            }
            "scatter" => {
                if let [x_arg, y_arg, ..] = args {
                    let (x_data, y_data) = self.extract_xy_pair(x_arg, y_arg)?;
                    let scatter_plot = runmat_plot::plots::ScatterPlot::new(x_data, y_data)
                        .map_err(KernelError::Execution)?;
                    figure.add_scatter_plot(scatter_plot);
                }
            }
            "bar" => {
                if let Some(y_arg) = args.first() {
                    let y_data = self.extract_numeric_array(y_arg)?;
                    let x_labels: Vec<String> = (0..y_data.len()).map(|i| i.to_string()).collect();
                    let bar_chart = runmat_plot::plots::BarChart::new(x_labels, y_data)
                        .map_err(KernelError::Execution)?;
                    figure.add_bar_chart(bar_chart);
                }
            }
            "hist" => {
                if let Some(data_arg) = args.first() {
                    let data = self.extract_numeric_array(data_arg)?;
                    let bins = match args.get(1) {
                        Some(bins_arg) => Self::histogram_bin_count(self.extract_number(bins_arg)?)?,
                        None => Self::DEFAULT_HISTOGRAM_BINS,
                    };
                    let (labels, counts) = self.build_histogram_series(&data, bins)?;
                    let histogram = runmat_plot::plots::BarChart::new(labels, counts)
                        .map_err(KernelError::Execution)?;
                    figure.add_bar_chart(histogram);
                }
            }
            _ => {
                return Err(KernelError::Execution(format!(
                    "Unknown plot function: {function_name}"
                )));
            }
        }
        self.register_plot(figure)
    }

    /// Extracts two numeric series and checks that they have the same length.
    /// `x` is extracted first, so its errors take precedence.
    fn extract_xy_pair(
        &self,
        x_arg: &JsonValue,
        y_arg: &JsonValue,
    ) -> Result<(Vec<f64>, Vec<f64>)> {
        let x_data = self.extract_numeric_array(x_arg)?;
        let y_data = self.extract_numeric_array(y_arg)?;
        if x_data.len() != y_data.len() {
            return Err(KernelError::Execution(
                "X and Y data must have the same length".to_string(),
            ));
        }
        Ok((x_data, y_data))
    }

    /// Validates a caller-supplied histogram bin count.
    ///
    /// Accepts values in `1..=MAX_HISTOGRAM_BINS`; fractional counts truncate toward
    /// zero. NaN, infinities and out-of-range values are rejected instead of being
    /// saturated by an `as` cast.
    fn histogram_bin_count(requested: f64) -> Result<usize> {
        if (1.0..=Self::MAX_HISTOGRAM_BINS as f64).contains(&requested) {
            Ok(requested as usize)
        } else {
            Err(KernelError::Execution(format!(
                "Histogram bin count must be between 1 and {}",
                Self::MAX_HISTOGRAM_BINS
            )))
        }
    }

    /// Converts a JSON number or an array of JSON numbers into `Vec<f64>`.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` when an array element is not a number, when a
    /// number cannot be represented as `f64`, or when the value is neither an array
    /// nor a number.
    fn extract_numeric_array(&self, value: &JsonValue) -> Result<Vec<f64>> {
        match value {
            JsonValue::Array(items) => items
                .iter()
                .map(|item| {
                    item.as_f64()
                        .or_else(|| item.as_i64().map(|n| n as f64))
                        .ok_or_else(|| {
                            KernelError::Execution("Array must contain only numbers".to_string())
                        })
                })
                .collect(),
            JsonValue::Number(num) => num
                .as_f64()
                .map(|val| vec![val])
                .ok_or_else(|| KernelError::Execution("Invalid number format".to_string())),
            _ => Err(KernelError::Execution(
                "Expected array or number".to_string(),
            )),
        }
    }

    /// Buckets the finite values of `data` into `bins` equal-width bins.
    ///
    /// Returns `"start-end"` labels (3 decimals) and per-bin counts. Non-finite values
    /// are ignored. The maximum value is counted in the last bin. When all values are
    /// equal, a minimal span of `1e-9` is used.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` when `data` is empty or contains no finite values.
    fn build_histogram_series(&self, data: &[f64], bins: usize) -> Result<(Vec<String>, Vec<f64>)> {
        if data.is_empty() {
            return Err(KernelError::Execution(
                "Histogram requires at least one data point".to_string(),
            ));
        }
        let bins = bins.max(1);
        let (min_val, max_val) = data
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        if !min_val.is_finite() || !max_val.is_finite() {
            return Err(KernelError::Execution(
                "Histogram data must be finite".to_string(),
            ));
        }
        let span = (max_val - min_val).max(1e-9);
        let bucket_width = span / bins as f64;
        let mut counts = vec![0f64; bins];
        for value in data.iter().copied().filter(|v| v.is_finite()) {
            // `max_val` maps exactly onto index `bins`, so clamp it into the last bucket.
            let idx = ((value - min_val) / bucket_width).floor().max(0.0) as usize;
            counts[idx.min(bins - 1)] += 1.0;
        }
        let labels = (0..bins)
            .map(|i| {
                let start = min_val + bucket_width * i as f64;
                let end = start + bucket_width;
                format!("{start:.3}-{end:.3}")
            })
            .collect();
        Ok((labels, counts))
    }

    /// Reads a single JSON number as `f64`.
    ///
    /// # Errors
    /// Returns `KernelError::Execution` if `value` is not a number or cannot be
    /// represented as `f64`.
    fn extract_number(&self, value: &JsonValue) -> Result<f64> {
        match value {
            JsonValue::Number(num) => num
                .as_f64()
                .ok_or_else(|| KernelError::Execution("Invalid number format".to_string())),
            _ => Err(KernelError::Execution("Expected number".to_string())),
        }
    }
}
// Gated like the type it implements for; without the feature `JupyterPlottingManager`
// does not exist and this impl would fail to compile.
#[cfg(feature = "jupyter")]
impl Default for JupyterPlottingManager {
    /// Equivalent to [`JupyterPlottingManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Hook for types that route plotting built-ins to a [`JupyterPlottingManager`].
///
/// Gated on the `jupyter` feature because its signatures use `JsonValue`,
/// [`DisplayData`] and [`JupyterPlottingManager`], which only exist with that feature.
#[cfg(feature = "jupyter")]
pub trait JupyterPlottingExtension {
    /// Handles the plotting call `function_name` with JSON-encoded `args`.
    ///
    /// NOTE(review): implementations presumably forward to
    /// `JupyterPlottingManager::handle_plot_function`, where `Ok(None)` means nothing
    /// should be displayed — confirm with implementors.
    fn handle_jupyter_plot(
        &mut self,
        function_name: &str,
        args: &[JsonValue],
    ) -> Result<Option<DisplayData>>;

    /// Mutable access to the manager backing this extension.
    fn plotting_manager(&mut self) -> &mut JupyterPlottingManager;
}
// Everything under test is defined only with the `jupyter` feature, so gate on both;
// otherwise `cargo test` without the feature fails to compile.
#[cfg(all(test, feature = "jupyter"))]
mod tests {
    use super::*;

    /// A default manager uses HTML output, auto-display, and starts empty.
    #[test]
    fn test_jupyter_plotting_manager_creation() {
        let manager = JupyterPlottingManager::new();
        assert_eq!(manager.config.output_format, OutputFormat::HTML);
        assert!(manager.config.auto_display);
        assert_eq!(manager.active_plots.len(), 0);
    }

    /// `update_config` replaces the stored configuration.
    #[test]
    fn test_config_update() {
        let mut manager = JupyterPlottingManager::new();
        let new_config = JupyterPlottingConfig {
            output_format: OutputFormat::SVG,
            auto_display: false,
            max_plots: 50,
            inline_display: false,
            image_width: 1024,
            image_height: 768,
        };
        manager.update_config(new_config);
        assert_eq!(manager.config.output_format, OutputFormat::SVG);
        assert!(!manager.config.auto_display);
        assert_eq!(manager.config.max_plots, 50);
    }

    /// Registering stores and displays a plot; clearing resets storage and numbering.
    #[test]
    fn test_plot_management() {
        let mut manager = JupyterPlottingManager::new();
        let figure = Figure::new().with_title("Test Plot");
        let display_data = manager.register_plot(figure).unwrap();
        assert!(display_data.is_some());
        assert_eq!(manager.active_plots.len(), 1);
        assert_eq!(manager.list_plots().len(), 1);
        manager.clear_plots();
        assert_eq!(manager.active_plots.len(), 0);
        assert_eq!(manager.plot_counter, 0);
    }

    /// Integer JSON arrays convert to `f64` vectors.
    #[test]
    fn test_extract_numeric_array() {
        let manager = JupyterPlottingManager::new();
        let json_array = JsonValue::Array(vec![
            JsonValue::Number(serde_json::Number::from(1)),
            JsonValue::Number(serde_json::Number::from(2)),
            JsonValue::Number(serde_json::Number::from(3)),
        ]);
        let result = manager.extract_numeric_array(&json_array).unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0]);
    }

    /// A `plot(x, y)` call produces display data and registers one figure.
    #[test]
    fn test_plot_function_handling() {
        let mut manager = JupyterPlottingManager::new();
        let x_data = JsonValue::Array(vec![
            JsonValue::Number(serde_json::Number::from(1)),
            JsonValue::Number(serde_json::Number::from(2)),
            JsonValue::Number(serde_json::Number::from(3)),
        ]);
        let y_data = JsonValue::Array(vec![
            JsonValue::Number(serde_json::Number::from(2)),
            JsonValue::Number(serde_json::Number::from(4)),
            JsonValue::Number(serde_json::Number::from(6)),
        ]);
        let result = manager
            .handle_plot_function("plot", &[x_data, y_data])
            .unwrap();
        assert!(result.is_some());
        assert_eq!(manager.active_plots.len(), 1);
    }
}