Skip to main content

god_graph/export/
svg.rs

//! SVG 可视化导出
//!
//! 支持将图导出为 SVG 格式,用于 Web 可视化
//!
//! ## 示例
//!
//! ```rust,ignore
//! use god_gragh::graph::Graph;
//! use god_gragh::export::svg::{to_svg_with_options, SvgOptions};
//!
//! let graph: Graph<String, f64> = Graph::directed();
//! // ... 添加节点和边
//!
//! let options = SvgOptions::new().with_size(800, 600);
//! let svg = to_svg_with_options(&graph, &options);
//! std::fs::write("graph.svg", svg).unwrap();
//! ```
17
18use crate::graph::traits::GraphQuery;
19use crate::graph::Graph;
20use crate::node::NodeIndex;
21use std::collections::HashMap;
22
23/// SVG 可视化选项
24#[derive(Debug, Clone)]
25pub struct SvgOptions {
26    /// SVG 宽度(像素)
27    pub width: u32,
28    /// SVG 高度(像素)
29    pub height: u32,
30    /// 节点半径(像素)
31    pub node_radius: f64,
32    /// 节点填充颜色
33    pub node_fill: String,
34    /// 节点描边颜色
35    pub node_stroke: String,
36    /// 节点描边宽度
37    pub node_stroke_width: f64,
38    /// 边颜色
39    pub edge_color: String,
40    /// 边宽度
41    pub edge_width: f64,
42    /// 字体大小
43    pub font_size: f64,
44    /// 字体颜色
45    pub font_color: String,
46    /// 是否显示节点标签
47    pub show_labels: bool,
48    /// 是否显示边权重
49    pub show_weights: bool,
50    /// 布局算法(目前仅支持 force-directed)
51    pub layout: LayoutAlgorithm,
52}
53
/// Strategy used to assign canvas coordinates to the graph's nodes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LayoutAlgorithm {
    /// Iterative force-directed (spring) layout.
    ForceDirected,
    /// Nodes spaced evenly around a circle.
    Circular,
    /// Layout driven by a topological order of the nodes; intended for DAGs. Graphs that
    /// cannot be topologically sorted fall back to the circular layout.
    Hierarchical,
}
64
65impl Default for SvgOptions {
66    fn default() -> Self {
67        Self {
68            width: 800,
69            height: 600,
70            node_radius: 20.0,
71            node_fill: "#4A90D9".to_string(),
72            node_stroke: "#2C5282".to_string(),
73            node_stroke_width: 2.0,
74            edge_color: "#A0AEC0".to_string(),
75            edge_width: 1.5,
76            font_size: 12.0,
77            font_color: "#2D3748".to_string(),
78            show_labels: true,
79            show_weights: false,
80            layout: LayoutAlgorithm::ForceDirected,
81        }
82    }
83}
84
85impl SvgOptions {
86    /// 创建默认 SVG 选项
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// 设置 SVG 尺寸
92    pub fn with_size(mut self, width: u32, height: u32) -> Self {
93        self.width = width;
94        self.height = height;
95        self
96    }
97
98    /// 设置节点半径
99    pub fn with_node_radius(mut self, radius: f64) -> Self {
100        self.node_radius = radius;
101        self
102    }
103
104    /// 设置是否显示标签
105    pub fn with_labels(mut self, show: bool) -> Self {
106        self.show_labels = show;
107        self
108    }
109
110    /// 设置布局算法
111    pub fn with_layout(mut self, layout: LayoutAlgorithm) -> Self {
112        self.layout = layout;
113        self
114    }
115}
116
117/// 使用默认选项将图导出为 SVG 格式
118pub fn to_svg<T: std::fmt::Display, E: std::fmt::Display + Clone>(graph: &Graph<T, E>) -> String {
119    to_svg_with_options(graph, &SvgOptions::default())
120}
121
122/// 使用自定义选项将图导出为 SVG 格式
123pub fn to_svg_with_options<T: std::fmt::Display, E: std::fmt::Display + Clone>(
124    graph: &Graph<T, E>,
125    options: &SvgOptions,
126) -> String {
127    let mut output = String::new();
128
129    // SVG 头部
130    output.push_str(&format!(
131        r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">"#,
132        options.width, options.height, options.width, options.height
133    ));
134    output.push('\n');
135
136    // 背景
137    output.push_str(r##"<rect width="100%" height="100%" fill="#FFFFFF"/>"##);
138    output.push('\n');
139
140    // 计算节点位置
141    let positions = compute_layout(graph, options);
142
143    // 绘制边
144    for edge in graph.edges() {
145        let src = edge.source();
146        let tgt = edge.target();
147        if let (Some(&src_pos), Some(&tgt_pos)) = (positions.get(&src), positions.get(&tgt)) {
148            let (x1, y1) = src_pos;
149            let (x2, y2) = tgt_pos;
150
151            output.push_str(&format!(
152                r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="{}" fill="none"/>"#,
153                x1, y1, x2, y2, options.edge_color, options.edge_width
154            ));
155            output.push('\n');
156
157            // 绘制箭头(有向图)
158            draw_arrow(&mut output, x1, y1, x2, y2, options);
159        }
160    }
161
162    // 绘制节点
163    for node in graph.nodes() {
164        let idx = node.index();
165        if let Some(&(x, y)) = positions.get(&idx) {
166            // 节点圆形
167            output.push_str(&format!(
168                r#"<circle cx="{}" cy="{}" r="{}" fill="{}" stroke="{}" stroke-width="{}"/>"#,
169                x,
170                y,
171                options.node_radius,
172                options.node_fill,
173                options.node_stroke,
174                options.node_stroke_width
175            ));
176            output.push('\n');
177
178            // 节点标签
179            if options.show_labels {
180                let label = format!("{}", node.data());
181                output.push_str(&format!(
182                    r#"<text x="{}" y="{}" font-size="{}" fill="{}" text-anchor="middle" dominant-baseline="central">{}</text>"#,
183                    x, y, options.font_size, options.font_color, escape_xml(&label)
184                ));
185                output.push('\n');
186            }
187        }
188    }
189
190    output.push_str("</svg>");
191    output
192}
193
194/// 计算节点布局位置
195fn compute_layout<T, E: Clone>(
196    graph: &Graph<T, E>,
197    options: &SvgOptions,
198) -> HashMap<NodeIndex, (f64, f64)> {
199    let nodes: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
200    let n = nodes.len();
201
202    if n == 0 {
203        return HashMap::new();
204    }
205
206    match options.layout {
207        LayoutAlgorithm::Circular => compute_circular_layout(&nodes, options),
208        LayoutAlgorithm::Hierarchical => compute_hierarchical_layout(graph, options),
209        LayoutAlgorithm::ForceDirected => compute_force_directed_layout(graph, &nodes, options),
210    }
211}
212
213/// 圆形布局
214fn compute_circular_layout(
215    nodes: &[NodeIndex],
216    options: &SvgOptions,
217) -> HashMap<NodeIndex, (f64, f64)> {
218    let mut positions = HashMap::new();
219    let n = nodes.len();
220    let center_x = options.width as f64 / 2.0;
221    let center_y = options.height as f64 / 2.0;
222    let radius = (options.width.min(options.height) as f64 / 2.0) * 0.8;
223
224    for (i, &node) in nodes.iter().enumerate() {
225        let angle = 2.0 * std::f64::consts::PI * (i as f64) / (n as f64);
226        let x = center_x + radius * angle.cos();
227        let y = center_y + radius * angle.sin();
228        positions.insert(node, (x, y));
229    }
230
231    positions
232}
233
234/// 力导向布局(简化版)
235fn compute_force_directed_layout<T, E>(
236    graph: &Graph<T, E>,
237    nodes: &[NodeIndex],
238    options: &SvgOptions,
239) -> HashMap<NodeIndex, (f64, f64)> {
240    let mut positions = HashMap::new();
241    let n = nodes.len();
242    let center_x = options.width as f64 / 2.0;
243    let center_y = options.height as f64 / 2.0;
244
245    // 初始随机位置
246    use std::collections::hash_map::DefaultHasher;
247    use std::hash::{Hash, Hasher};
248
249    for &node in nodes.iter() {
250        let mut hasher = DefaultHasher::new();
251        node.hash(&mut hasher);
252        let seed = hasher.finish() as f64;
253        let angle = seed * 0.001;
254        let radius = 50.0 + ((seed as u64) % 200) as f64;
255        let x = center_x + radius * angle.cos();
256        let y = center_y + radius * angle.sin();
257        positions.insert(node, (x, y));
258    }
259
260    // 迭代优化(简化版力导向)
261    let iterations = 50;
262    let repulsion = 1000.0;
263    let attraction = 0.01;
264    let damping = 0.85;
265
266    let mut velocities: HashMap<NodeIndex, (f64, f64)> =
267        nodes.iter().map(|&n| (n, (0.0, 0.0))).collect();
268
269    for _ in 0..iterations {
270        let mut forces: HashMap<NodeIndex, (f64, f64)> =
271            nodes.iter().map(|&n| (n, (0.0, 0.0))).collect();
272
273        // 斥力(节点之间)
274        for i in 0..n {
275            for j in (i + 1)..n {
276                let ni = nodes[i];
277                let nj = nodes[j];
278                let (xi, yi) = positions[&ni];
279                let (xj, yj) = positions[&nj];
280
281                let dx = xi - xj;
282                let dy = yi - yj;
283                let dist = (dx * dx + dy * dy).sqrt().max(1.0);
284
285                let force = repulsion / (dist * dist);
286                let fx = force * dx / dist;
287                let fy = force * dy / dist;
288
289                let (fix, fiy) = forces.get_mut(&ni).unwrap();
290                *fix += fx;
291                *fiy += fy;
292
293                let (fjx, fjy) = forces.get_mut(&nj).unwrap();
294                *fjx -= fx;
295                *fjy -= fy;
296            }
297        }
298
299        // 引力(边连接的节点)
300        for edge in graph.edges() {
301            let src = edge.source();
302            let tgt = edge.target();
303            if positions.contains_key(&src) && positions.contains_key(&tgt) {
304                let (xs, ys) = positions[&src];
305                let (xt, yt) = positions[&tgt];
306
307                let dx = xt - xs;
308                let dy = yt - ys;
309                let dist = (dx * dx + dy * dy).sqrt().max(1.0);
310
311                let force = attraction * dist;
312                let fx = force * dx / dist;
313                let fy = force * dy / dist;
314
315                let (fsx, fsy) = forces.get_mut(&src).unwrap();
316                *fsx += fx;
317                *fsy += fy;
318
319                let (ftx, fty) = forces.get_mut(&tgt).unwrap();
320                *ftx -= fx;
321                *fty -= fy;
322            }
323        }
324
325        // 向中心引力
326        for &node in nodes {
327            let (x, y) = positions[&node];
328            let dx = center_x - x;
329            let dy = center_y - y;
330            let (fx, fy) = forces.get_mut(&node).unwrap();
331            *fx += dx * 0.001;
332            *fy += dy * 0.001;
333        }
334
335        // 更新位置
336        for &node in nodes {
337            let (fx, fy) = forces[&node];
338            let (vx, vy) = velocities.get_mut(&node).unwrap();
339            *vx = (*vx + fx) * damping;
340            *vy = (*vy + fy) * damping;
341
342            let (x, y) = positions.get_mut(&node).unwrap();
343            *x += *vx;
344            *y += *vy;
345
346            // 边界限制
347            let margin = options.node_radius + 5.0;
348            *x = (*x).max(margin).min(options.width as f64 - margin);
349            *y = (*y).max(margin).min(options.height as f64 - margin);
350        }
351    }
352
353    positions
354}
355
356/// 层次布局(简化版,按拓扑排序)
357fn compute_hierarchical_layout<T, E: Clone>(
358    graph: &Graph<T, E>,
359    options: &SvgOptions,
360) -> HashMap<NodeIndex, (f64, f64)> {
361    use crate::algorithms::traversal::topological_sort;
362
363    let mut positions = HashMap::new();
364    let nodes_result = topological_sort(graph);
365
366    // 如果有环或错误,回退到圆形布局
367    let nodes = match nodes_result {
368        Ok(n) => n,
369        Err(_) => {
370            return compute_circular_layout(
371                &graph.nodes().map(|n| n.index()).collect::<Vec<_>>(),
372                options,
373            )
374        }
375    };
376
377    if nodes.is_empty() {
378        return compute_circular_layout(
379            &graph.nodes().map(|n| n.index()).collect::<Vec<_>>(),
380            options,
381        );
382    }
383
384    let n = nodes.len();
385    let levels: Vec<Vec<NodeIndex>> = vec![nodes]; // 简化:单层
386    let num_levels = levels.len();
387
388    let level_height = options.height as f64 / (num_levels as f64 + 1.0);
389    let node_spacing = options.width as f64 / (n as f64 + 1.0);
390
391    for (level_idx, level_nodes) in levels.iter().enumerate() {
392        let y = level_height * (level_idx as f64 + 1.0);
393        for (node_idx, &node) in level_nodes.iter().enumerate() {
394            let x = node_spacing * (node_idx as f64 + 1.0);
395            positions.insert(node, (x, y));
396        }
397    }
398
399    positions
400}
401
402/// 绘制箭头
403fn draw_arrow(output: &mut String, x1: f64, y1: f64, x2: f64, y2: f64, options: &SvgOptions) {
404    let arrow_size = 8.0;
405    let angle = (y2 - y1).atan2(x2 - x1);
406    let arrow_angle = std::f64::consts::FRAC_PI_4;
407
408    // 计算箭头尖端位置(在节点边缘)
409    let dist = ((x2 - x1).powi(2) + (y2 - y1).powi(2)).sqrt();
410    let stop_dist = dist - options.node_radius;
411
412    if stop_dist < 0.0 {
413        return; // 节点重叠,不绘制箭头
414    }
415
416    let x1_adj = x1 + (x2 - x1) * (stop_dist / dist);
417    let y1_adj = y1 + (y2 - y1) * (stop_dist / dist);
418
419    // 箭头左翼
420    let left_angle = angle + arrow_angle;
421    let x_left = x1_adj - arrow_size * left_angle.cos();
422    let y_left = y1_adj - arrow_size * left_angle.sin();
423
424    // 箭头右翼
425    let right_angle = angle - arrow_angle;
426    let x_right = x1_adj - arrow_size * right_angle.cos();
427    let y_right = y1_adj - arrow_size * right_angle.sin();
428
429    output.push_str(&format!(
430        r#"<polygon points="{},{} {},{} {},{}" fill="{}" stroke="none"/>"#,
431        x1_adj, y1_adj, x_left, y_left, x_right, y_right, options.edge_color
432    ));
433    output.push('\n');
434}
435
/// Escapes `s` for safe inclusion in XML text content or in a quoted attribute value.
///
/// - The five predefined XML entities are substituted for `&`, `<`, `>`, `"` and `'`.
/// - Tab, line feed and carriage return pass through unchanged.
/// - Characters that XML 1.0 forbids anywhere in a document are replaced with U+FFFD: the other
///   C0 controls (U+0000–U+001F), U+FFFE and U+FFFF. No escape form can represent them, and
///   emitting them verbatim would make the SVG unparsable.
///
/// Runs in a single pass.
fn escape_xml(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            '\t' | '\n' | '\r' => escaped.push(ch),
            c if c < '\u{20}' || c == '\u{FFFE}' || c == '\u{FFFF}' => escaped.push('\u{FFFD}'),
            c => escaped.push(c),
        }
    }
    escaped
}
444
/// Writes the SVG document `svg` to the file at `path`.
///
/// Delegates to [`std::fs::write`], which creates the file if needed and replaces any existing
/// contents.
///
/// # Errors
/// Returns any I/O error reported by [`std::fs::write`].
pub fn write_svg_to_file(svg: &str, path: &str) -> std::io::Result<()> {
    let contents = svg.as_bytes();
    std::fs::write(path, contents)
}
449
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::traits::GraphOps;
    use crate::graph::Graph;

    /// Builds a directed graph with one node per entry of `labels` and one weight-`1.0` edge per
    /// `(source, target)` pair, where the pair holds positions in `labels`.
    fn build_graph(labels: &[&str], edges: &[(usize, usize)]) -> Graph<String, f64> {
        let mut graph: Graph<String, f64> = Graph::directed();
        let ids: Vec<_> = labels
            .iter()
            .map(|label| graph.add_node(label.to_string()).unwrap())
            .collect();
        for &(src, tgt) in edges {
            graph.add_edge(ids[src], ids[tgt], 1.0).unwrap();
        }
        graph
    }

    #[test]
    fn test_svg_export_basic() {
        let graph = build_graph(&["A", "B"], &[(0, 1)]);

        let svg = to_svg(&graph);
        assert!(svg.contains("<svg"));
        assert!(svg.contains("</svg>"));
        assert!(svg.contains("<circle"));
        assert!(svg.contains("<line"));
    }

    #[test]
    fn test_svg_options() {
        let graph = build_graph(&["A", "B"], &[(0, 1)]);

        let options = SvgOptions::new()
            .with_size(400, 300)
            .with_node_radius(15.0)
            .with_labels(false);

        let svg = to_svg_with_options(&graph, &options);
        assert!(svg.contains(r#"width="400""#));
        assert!(svg.contains(r#"height="300""#));
    }

    #[test]
    fn test_svg_empty_graph() {
        let graph: Graph<String, f64> = Graph::directed();
        let svg = to_svg(&graph);
        assert!(svg.contains("<svg"));
        assert!(svg.contains("</svg>"));
    }

    #[test]
    fn test_svg_escapes_labels() {
        let graph = build_graph(&["<A&B>"], &[]);
        let svg = to_svg(&graph);
        assert!(svg.contains("&lt;A&amp;B&gt;"));
        assert!(!svg.contains("<A&B>"));
    }

    #[test]
    fn test_svg_labels_can_be_disabled() {
        let graph = build_graph(&["A", "B"], &[(0, 1)]);
        let svg = to_svg_with_options(&graph, &SvgOptions::new().with_labels(false));
        assert!(!svg.contains("<text"));
    }

    #[test]
    fn test_svg_every_layout_renders_all_elements() {
        let graph = build_graph(&["A", "B", "C"], &[(0, 1), (1, 2)]);
        for &layout in &[
            LayoutAlgorithm::ForceDirected,
            LayoutAlgorithm::Circular,
            LayoutAlgorithm::Hierarchical,
        ] {
            let svg = to_svg_with_options(&graph, &SvgOptions::new().with_layout(layout));
            assert_eq!(svg.matches("<circle").count(), 3, "{:?}", layout);
            assert_eq!(svg.matches("<line").count(), 2, "{:?}", layout);
            assert!(!svg.contains("NaN"), "{:?}", layout);
        }
    }

    #[test]
    fn test_svg_circular_layout_draws_arrowheads() {
        let graph = build_graph(&["A", "B"], &[(0, 1)]);
        let options = SvgOptions::new().with_layout(LayoutAlgorithm::Circular);
        let svg = to_svg_with_options(&graph, &options);
        assert!(svg.contains("<polygon"));
    }

    #[test]
    fn test_svg_hierarchical_with_cycle_still_renders() {
        let graph = build_graph(&["A", "B"], &[(0, 1), (1, 0)]);
        let options = SvgOptions::new().with_layout(LayoutAlgorithm::Hierarchical);
        let svg = to_svg_with_options(&graph, &options);
        assert_eq!(svg.matches("<circle").count(), 2);
    }

    #[test]
    fn test_escape_xml_predefined_entities() {
        assert_eq!(escape_xml(r#"a<b>&"c'"#), "a&lt;b&gt;&amp;&quot;c&apos;");
        assert_eq!(escape_xml(""), "");
    }
}