1use crate::graph::traits::GraphQuery;
19use crate::graph::Graph;
20use crate::node::NodeIndex;
21use std::collections::HashMap;
22
/// Rendering options for SVG export.
///
/// Start from [`SvgOptions::new`] (equivalently `SvgOptions::default()`) and
/// refine it with the `with_*` builder methods, or assign the public fields
/// directly.
#[derive(Debug, Clone)]
pub struct SvgOptions {
    /// Canvas width; also used as the width of the SVG `viewBox`.
    pub width: u32,
    /// Canvas height; also used as the height of the SVG `viewBox`.
    pub height: u32,
    /// Radius of every node circle.
    pub node_radius: f64,
    /// Fill colour of node circles (an SVG colour string such as `#RRGGBB`).
    pub node_fill: String,
    /// Outline colour of node circles.
    pub node_stroke: String,
    /// Outline width of node circles.
    pub node_stroke_width: f64,
    /// Colour used for edge lines and their arrowheads.
    pub edge_color: String,
    /// Stroke width of edge lines.
    pub edge_width: f64,
    /// Font size of node labels.
    pub font_size: f64,
    /// Colour of node label text.
    pub font_color: String,
    /// When `true`, each node's `Display` text is drawn as a centred label.
    pub show_labels: bool,
    /// Whether edge weights should be displayed.
    /// NOTE(review): not read by the SVG renderer in this module.
    pub show_weights: bool,
    /// Strategy used to position the nodes on the canvas.
    pub layout: LayoutAlgorithm,
}

/// Strategy used to assign canvas positions to nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutAlgorithm {
    /// Iterative spring simulation: nodes repel each other, edges pull their
    /// endpoints together, and a weak force pulls everything to the centre.
    ForceDirected,
    /// Nodes spaced evenly around a circle centred on the canvas.
    Circular,
    /// Positions derived from a topological sort of the graph; falls back to
    /// the circular layout when the sort fails.
    Hierarchical,
}

impl Default for SvgOptions {
    /// An 800x600 canvas with labelled blue nodes of radius 20, grey edges,
    /// hidden weights and the force-directed layout.
    fn default() -> Self {
        SvgOptions {
            width: 800,
            height: 600,
            layout: LayoutAlgorithm::ForceDirected,
            node_radius: 20.0,
            node_stroke_width: 2.0,
            node_fill: String::from("#4A90D9"),
            node_stroke: String::from("#2C5282"),
            edge_width: 1.5,
            edge_color: String::from("#A0AEC0"),
            font_size: 12.0,
            font_color: String::from("#2D3748"),
            show_labels: true,
            show_weights: false,
        }
    }
}

impl SvgOptions {
    /// Creates options holding the default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns these options with the canvas resized to `width` x `height`.
    pub fn with_size(self, width: u32, height: u32) -> Self {
        Self {
            width,
            height,
            ..self
        }
    }

    /// Returns these options with node circles of the given radius.
    pub fn with_node_radius(self, radius: f64) -> Self {
        Self {
            node_radius: radius,
            ..self
        }
    }

    /// Returns these options with node labels switched on or off.
    pub fn with_labels(self, show: bool) -> Self {
        Self {
            show_labels: show,
            ..self
        }
    }

    /// Returns these options using the given layout algorithm.
    pub fn with_layout(self, layout: LayoutAlgorithm) -> Self {
        Self { layout, ..self }
    }
}
116
117pub fn to_svg<T: std::fmt::Display, E: std::fmt::Display + Clone>(graph: &Graph<T, E>) -> String {
119 to_svg_with_options(graph, &SvgOptions::default())
120}
121
122pub fn to_svg_with_options<T: std::fmt::Display, E: std::fmt::Display + Clone>(
124 graph: &Graph<T, E>,
125 options: &SvgOptions,
126) -> String {
127 let mut output = String::new();
128
129 output.push_str(&format!(
131 r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">"#,
132 options.width, options.height, options.width, options.height
133 ));
134 output.push('\n');
135
136 output.push_str(r##"<rect width="100%" height="100%" fill="#FFFFFF"/>"##);
138 output.push('\n');
139
140 let positions = compute_layout(graph, options);
142
143 for edge in graph.edges() {
145 let src = edge.source();
146 let tgt = edge.target();
147 if let (Some(&src_pos), Some(&tgt_pos)) = (positions.get(&src), positions.get(&tgt)) {
148 let (x1, y1) = src_pos;
149 let (x2, y2) = tgt_pos;
150
151 output.push_str(&format!(
152 r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="{}" fill="none"/>"#,
153 x1, y1, x2, y2, options.edge_color, options.edge_width
154 ));
155 output.push('\n');
156
157 draw_arrow(&mut output, x1, y1, x2, y2, options);
159 }
160 }
161
162 for node in graph.nodes() {
164 let idx = node.index();
165 if let Some(&(x, y)) = positions.get(&idx) {
166 output.push_str(&format!(
168 r#"<circle cx="{}" cy="{}" r="{}" fill="{}" stroke="{}" stroke-width="{}"/>"#,
169 x,
170 y,
171 options.node_radius,
172 options.node_fill,
173 options.node_stroke,
174 options.node_stroke_width
175 ));
176 output.push('\n');
177
178 if options.show_labels {
180 let label = format!("{}", node.data());
181 output.push_str(&format!(
182 r#"<text x="{}" y="{}" font-size="{}" fill="{}" text-anchor="middle" dominant-baseline="central">{}</text>"#,
183 x, y, options.font_size, options.font_color, escape_xml(&label)
184 ));
185 output.push('\n');
186 }
187 }
188 }
189
190 output.push_str("</svg>");
191 output
192}
193
194fn compute_layout<T, E: Clone>(
196 graph: &Graph<T, E>,
197 options: &SvgOptions,
198) -> HashMap<NodeIndex, (f64, f64)> {
199 let nodes: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
200 let n = nodes.len();
201
202 if n == 0 {
203 return HashMap::new();
204 }
205
206 match options.layout {
207 LayoutAlgorithm::Circular => compute_circular_layout(&nodes, options),
208 LayoutAlgorithm::Hierarchical => compute_hierarchical_layout(graph, options),
209 LayoutAlgorithm::ForceDirected => compute_force_directed_layout(graph, &nodes, options),
210 }
211}
212
213fn compute_circular_layout(
215 nodes: &[NodeIndex],
216 options: &SvgOptions,
217) -> HashMap<NodeIndex, (f64, f64)> {
218 let mut positions = HashMap::new();
219 let n = nodes.len();
220 let center_x = options.width as f64 / 2.0;
221 let center_y = options.height as f64 / 2.0;
222 let radius = (options.width.min(options.height) as f64 / 2.0) * 0.8;
223
224 for (i, &node) in nodes.iter().enumerate() {
225 let angle = 2.0 * std::f64::consts::PI * (i as f64) / (n as f64);
226 let x = center_x + radius * angle.cos();
227 let y = center_y + radius * angle.sin();
228 positions.insert(node, (x, y));
229 }
230
231 positions
232}
233
234fn compute_force_directed_layout<T, E>(
236 graph: &Graph<T, E>,
237 nodes: &[NodeIndex],
238 options: &SvgOptions,
239) -> HashMap<NodeIndex, (f64, f64)> {
240 let mut positions = HashMap::new();
241 let n = nodes.len();
242 let center_x = options.width as f64 / 2.0;
243 let center_y = options.height as f64 / 2.0;
244
245 use std::collections::hash_map::DefaultHasher;
247 use std::hash::{Hash, Hasher};
248
249 for &node in nodes.iter() {
250 let mut hasher = DefaultHasher::new();
251 node.hash(&mut hasher);
252 let seed = hasher.finish() as f64;
253 let angle = seed * 0.001;
254 let radius = 50.0 + ((seed as u64) % 200) as f64;
255 let x = center_x + radius * angle.cos();
256 let y = center_y + radius * angle.sin();
257 positions.insert(node, (x, y));
258 }
259
260 let iterations = 50;
262 let repulsion = 1000.0;
263 let attraction = 0.01;
264 let damping = 0.85;
265
266 let mut velocities: HashMap<NodeIndex, (f64, f64)> =
267 nodes.iter().map(|&n| (n, (0.0, 0.0))).collect();
268
269 for _ in 0..iterations {
270 let mut forces: HashMap<NodeIndex, (f64, f64)> =
271 nodes.iter().map(|&n| (n, (0.0, 0.0))).collect();
272
273 for i in 0..n {
275 for j in (i + 1)..n {
276 let ni = nodes[i];
277 let nj = nodes[j];
278 let (xi, yi) = positions[&ni];
279 let (xj, yj) = positions[&nj];
280
281 let dx = xi - xj;
282 let dy = yi - yj;
283 let dist = (dx * dx + dy * dy).sqrt().max(1.0);
284
285 let force = repulsion / (dist * dist);
286 let fx = force * dx / dist;
287 let fy = force * dy / dist;
288
289 let (fix, fiy) = forces.get_mut(&ni).unwrap();
290 *fix += fx;
291 *fiy += fy;
292
293 let (fjx, fjy) = forces.get_mut(&nj).unwrap();
294 *fjx -= fx;
295 *fjy -= fy;
296 }
297 }
298
299 for edge in graph.edges() {
301 let src = edge.source();
302 let tgt = edge.target();
303 if positions.contains_key(&src) && positions.contains_key(&tgt) {
304 let (xs, ys) = positions[&src];
305 let (xt, yt) = positions[&tgt];
306
307 let dx = xt - xs;
308 let dy = yt - ys;
309 let dist = (dx * dx + dy * dy).sqrt().max(1.0);
310
311 let force = attraction * dist;
312 let fx = force * dx / dist;
313 let fy = force * dy / dist;
314
315 let (fsx, fsy) = forces.get_mut(&src).unwrap();
316 *fsx += fx;
317 *fsy += fy;
318
319 let (ftx, fty) = forces.get_mut(&tgt).unwrap();
320 *ftx -= fx;
321 *fty -= fy;
322 }
323 }
324
325 for &node in nodes {
327 let (x, y) = positions[&node];
328 let dx = center_x - x;
329 let dy = center_y - y;
330 let (fx, fy) = forces.get_mut(&node).unwrap();
331 *fx += dx * 0.001;
332 *fy += dy * 0.001;
333 }
334
335 for &node in nodes {
337 let (fx, fy) = forces[&node];
338 let (vx, vy) = velocities.get_mut(&node).unwrap();
339 *vx = (*vx + fx) * damping;
340 *vy = (*vy + fy) * damping;
341
342 let (x, y) = positions.get_mut(&node).unwrap();
343 *x += *vx;
344 *y += *vy;
345
346 let margin = options.node_radius + 5.0;
348 *x = (*x).max(margin).min(options.width as f64 - margin);
349 *y = (*y).max(margin).min(options.height as f64 - margin);
350 }
351 }
352
353 positions
354}
355
356fn compute_hierarchical_layout<T, E: Clone>(
358 graph: &Graph<T, E>,
359 options: &SvgOptions,
360) -> HashMap<NodeIndex, (f64, f64)> {
361 use crate::algorithms::traversal::topological_sort;
362
363 let mut positions = HashMap::new();
364 let nodes_result = topological_sort(graph);
365
366 let nodes = match nodes_result {
368 Ok(n) => n,
369 Err(_) => {
370 return compute_circular_layout(
371 &graph.nodes().map(|n| n.index()).collect::<Vec<_>>(),
372 options,
373 )
374 }
375 };
376
377 if nodes.is_empty() {
378 return compute_circular_layout(
379 &graph.nodes().map(|n| n.index()).collect::<Vec<_>>(),
380 options,
381 );
382 }
383
384 let n = nodes.len();
385 let levels: Vec<Vec<NodeIndex>> = vec![nodes]; let num_levels = levels.len();
387
388 let level_height = options.height as f64 / (num_levels as f64 + 1.0);
389 let node_spacing = options.width as f64 / (n as f64 + 1.0);
390
391 for (level_idx, level_nodes) in levels.iter().enumerate() {
392 let y = level_height * (level_idx as f64 + 1.0);
393 for (node_idx, &node) in level_nodes.iter().enumerate() {
394 let x = node_spacing * (node_idx as f64 + 1.0);
395 positions.insert(node, (x, y));
396 }
397 }
398
399 positions
400}
401
402fn draw_arrow(output: &mut String, x1: f64, y1: f64, x2: f64, y2: f64, options: &SvgOptions) {
404 let arrow_size = 8.0;
405 let angle = (y2 - y1).atan2(x2 - x1);
406 let arrow_angle = std::f64::consts::FRAC_PI_4;
407
408 let dist = ((x2 - x1).powi(2) + (y2 - y1).powi(2)).sqrt();
410 let stop_dist = dist - options.node_radius;
411
412 if stop_dist < 0.0 {
413 return; }
415
416 let x1_adj = x1 + (x2 - x1) * (stop_dist / dist);
417 let y1_adj = y1 + (y2 - y1) * (stop_dist / dist);
418
419 let left_angle = angle + arrow_angle;
421 let x_left = x1_adj - arrow_size * left_angle.cos();
422 let y_left = y1_adj - arrow_size * left_angle.sin();
423
424 let right_angle = angle - arrow_angle;
426 let x_right = x1_adj - arrow_size * right_angle.cos();
427 let y_right = y1_adj - arrow_size * right_angle.sin();
428
429 output.push_str(&format!(
430 r#"<polygon points="{},{} {},{} {},{}" fill="{}" stroke="none"/>"#,
431 x1_adj, y1_adj, x_left, y_left, x_right, y_right, options.edge_color
432 ));
433 output.push('\n');
434}
435
/// Escapes the five XML special characters so `s` is safe inside element text
/// and inside double- or single-quoted attribute values.
///
/// `&` -> `&amp;`, `<` -> `&lt;`, `>` -> `&gt;`, `"` -> `&quot;`,
/// `'` -> `&#39;`. Every other character is copied unchanged. Runs in a
/// single pass, so an `&` that is already part of an entity is escaped again
/// (`"&amp;"` becomes `"&amp;amp;"`).
fn escape_xml(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            _ => escaped.push(ch),
        }
    }
    escaped
}
444
/// Writes an already-rendered SVG document to `path`.
///
/// The file is created if it does not exist and its contents are replaced if
/// it does. `path` may be anything path-like: `&str`, `String`, `&Path`,
/// `PathBuf`, and so on.
///
/// # Errors
///
/// Returns any I/O error reported by [`std::fs::write`].
pub fn write_svg_to_file<P: AsRef<std::path::Path>>(svg: &str, path: P) -> std::io::Result<()> {
    std::fs::write(path, svg)
}
449
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::traits::GraphOps;
    use crate::graph::Graph;

    /// Basic rendering: root element, a node circle and an edge line.
    #[test]
    fn test_svg_export_basic() {
        let mut graph: Graph<String, f64> = Graph::directed();
        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        graph.add_edge(a, b, 1.0).unwrap();

        let svg = to_svg(&graph);
        assert!(svg.contains("<svg"));
        assert!(svg.contains("</svg>"));
        assert!(svg.contains("<circle"));
        assert!(svg.contains("<line"));
    }

    /// Canvas size reaches the root element and disabled labels emit no text.
    #[test]
    fn test_svg_options() {
        let mut graph: Graph<String, f64> = Graph::directed();
        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        graph.add_edge(a, b, 1.0).unwrap();

        let options = SvgOptions::new()
            .with_size(400, 300)
            .with_node_radius(15.0)
            .with_labels(false);

        let svg = to_svg_with_options(&graph, &options);
        assert!(svg.contains(r#"width="400""#));
        assert!(svg.contains(r#"height="300""#));
        assert!(svg.contains(r#"viewBox="0 0 400 300""#));
        assert!(!svg.contains("<text"));
    }

    /// An empty graph still yields a well-formed document with no nodes.
    #[test]
    fn test_svg_empty_graph() {
        let graph: Graph<String, f64> = Graph::directed();
        let svg = to_svg(&graph);
        assert!(svg.contains("<svg"));
        assert!(svg.contains("</svg>"));
        assert!(!svg.contains("<circle"));
    }

    /// Node labels are rendered from each node's `Display` output.
    #[test]
    fn test_svg_labels_use_node_display() {
        let mut graph: Graph<String, f64> = Graph::directed();
        graph.add_node("A".to_string()).unwrap();
        graph.add_node("B".to_string()).unwrap();

        let svg = to_svg(&graph);
        assert!(svg.contains(">A</text>"));
        assert!(svg.contains(">B</text>"));
    }

    /// Every layout positions every node and edge with finite coordinates.
    #[test]
    fn test_every_layout_positions_every_node() {
        let mut graph: Graph<String, f64> = Graph::directed();
        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        let c = graph.add_node("C".to_string()).unwrap();
        graph.add_edge(a, b, 1.0).unwrap();
        graph.add_edge(b, c, 1.0).unwrap();
        graph.add_edge(a, c, 1.0).unwrap();

        for &layout in &[
            LayoutAlgorithm::ForceDirected,
            LayoutAlgorithm::Circular,
            LayoutAlgorithm::Hierarchical,
        ] {
            let svg = to_svg_with_options(&graph, &SvgOptions::new().with_layout(layout));
            assert_eq!(svg.matches("<circle").count(), 3, "{:?}", layout);
            assert_eq!(svg.matches("<line").count(), 3, "{:?}", layout);
            assert!(!svg.contains("NaN"), "{:?}", layout);
            assert!(!svg.contains("inf"), "{:?}", layout);
        }
    }
}