Skip to main content

physdes/
global_router.rs

1//! Global router for Steiner tree-based routing with keepout avoidance.
2//!
3//! Provides data structures and algorithms for constructing rectilinear
4//! Steiner routing trees with support for simple routing, Steiner point
5//! insertion, wirelength-constrained routing, and rectangular keepout
6//! avoidance. Includes an SVG visualizer for result inspection.
7
8use std::collections::HashMap;
9use std::fmt;
10
11use crate::generic::{Contain, MinDist};
12use crate::interval::{Hull, Interval};
13use crate::point::Point;
14
/// Type of a routing node.
///
/// A [`GlobalRoutingTree`] is created with a single `Source` node at its root;
/// every node inserted afterwards is either a `Steiner` point or a `Terminal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
    /// Intermediate branching point inserted to shorten the total wirelength.
    Steiner,
    /// A sink position that must be connected to the source.
    Terminal,
    /// The root of the routing tree.
    Source,
}
22
23impl fmt::Display for NodeType {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        match self {
26            NodeType::Steiner => write!(f, "Steiner"),
27            NodeType::Terminal => write!(f, "Terminal"),
28            NodeType::Source => write!(f, "Source"),
29        }
30    }
31}
32
33/// A node in the routing tree.
34#[derive(Debug, Clone)]
35pub struct RoutingNode {
36    /// Unique identifier for this node
37    pub id: String,
38    /// Type of this node (Source, Steiner, or Terminal)
39    pub node_type: NodeType,
40    /// Position of this node in the layout
41    pub pt: Point<i32, i32>,
42    /// Indices of child nodes in the tree
43    pub children: Vec<usize>,
44    /// Index of the parent node, if any
45    pub parent: Option<usize>,
46    /// Load capacitance at this node
47    pub capacitance: f64,
48    /// Signal delay at this node
49    pub delay: f64,
50    /// Path length from source to this node
51    pub path_length: i32,
52}
53
54impl RoutingNode {
55    /// Creates a new routing node with the given id, type, and position.
56    pub fn new(id: &str, node_type: NodeType, pt: Point<i32, i32>) -> Self {
57        RoutingNode {
58            id: id.to_string(),
59            node_type,
60            pt,
61            children: Vec::new(),
62            parent: None,
63            capacitance: 0.0,
64            delay: 0.0,
65            path_length: 0,
66        }
67    }
68
69    /// Computes the Manhattan distance to another routing node.
70    ///
71    /// $$d = |x_1 - x_2| + |y_1 - y_2|$$
72    pub fn manhattan_distance(&self, other: &RoutingNode) -> i32 {
73        self.pt.min_dist_with(&other.pt) as i32
74    }
75}
76
77/// A rectilinear Steiner routing tree with support for keepout avoidance.
78pub struct GlobalRoutingTree {
79    nodes: Vec<RoutingNode>,
80    node_map: HashMap<String, usize>,
81    source_idx: usize,
82    next_steiner_id: i32,
83    next_terminal_id: i32,
84    /// The worst-case (longest) wirelength among all source-to-terminal paths
85    pub worst_wirelength: i32,
86}
87
88impl GlobalRoutingTree {
89    pub fn new(source_position: Point<i32, i32>) -> Self {
90        let source = RoutingNode::new("source", NodeType::Source, source_position);
91        let mut nodes = Vec::new();
92        let mut node_map = HashMap::new();
93        node_map.insert("source".to_string(), 0usize);
94        nodes.push(source);
95        GlobalRoutingTree {
96            nodes,
97            node_map,
98            source_idx: 0,
99            next_steiner_id: 1,
100            next_terminal_id: 1,
101            worst_wirelength: 0,
102        }
103    }
104
105    /// Returns a shared reference to the source node.
106    pub fn get_source(&self) -> &RoutingNode {
107        &self.nodes[self.source_idx]
108    }
109
110    /// Returns a mutable reference to the source node.
111    pub fn get_source_mut(&mut self) -> &mut RoutingNode {
112        &mut self.nodes[self.source_idx]
113    }
114
115    fn add_node(&mut self, node: RoutingNode) -> usize {
116        let idx = self.nodes.len();
117        self.node_map.insert(node.id.clone(), idx);
118        self.nodes.push(node);
119        idx
120    }
121
122    fn _find_nearest_node(&self, point: Point<i32, i32>, exclude_id: Option<&str>) -> usize {
123        if self.nodes.len() <= 1 {
124            return self.source_idx;
125        }
126        let mut nearest = self.source_idx;
127        let mut min_dist = i32::MAX;
128        for (idx, node) in self.nodes.iter().enumerate() {
129            if let Some(ex) = exclude_id {
130                if node.id == ex {
131                    continue;
132                }
133            }
134            let dist = node.pt.min_dist_with(&point) as i32;
135            if dist < min_dist {
136                min_dist = dist;
137                nearest = idx;
138            }
139        }
140        nearest
141    }
142
143    pub fn insert_steiner_node(
144        &mut self,
145        point: Point<i32, i32>,
146        parent_id: Option<&str>,
147    ) -> String {
148        let id = format!("steiner_{}", self.next_steiner_id);
149        self.next_steiner_id += 1;
150        let idx = self.add_node(RoutingNode::new(&id, NodeType::Steiner, point));
151
152        let parent_idx = match parent_id {
153            Some(pid) => *self.node_map.get(pid).expect("Parent node not found"),
154            None => self.source_idx,
155        };
156        self.nodes[idx].parent = Some(parent_idx);
157        self.nodes[parent_idx].children.push(idx);
158        id
159    }
160
161    pub fn insert_terminal_node(
162        &mut self,
163        point: Point<i32, i32>,
164        parent_id: Option<&str>,
165    ) -> String {
166        let id = format!("terminal_{}", self.next_terminal_id);
167        self.next_terminal_id += 1;
168
169        let parent_idx = match parent_id {
170            Some(pid) => *self.node_map.get(pid).expect("Parent node not found"),
171            None => self._find_nearest_node(point, None),
172        };
173
174        let idx = self.add_node(RoutingNode::new(&id, NodeType::Terminal, point));
175        self.nodes[idx].parent = Some(parent_idx);
176        self.nodes[parent_idx].children.push(idx);
177        id
178    }
179
180    /// Insert a new node on an existing branch between two nodes (Python `insert_node_on_branch`).
181    #[allow(clippy::manual_contains)]
182    pub fn insert_node_on_branch(
183        &mut self,
184        node_type: NodeType,
185        point: Point<i32, i32>,
186        branch_start_id: &str,
187        branch_end_id: &str,
188    ) -> String {
189        let start_idx = *self
190            .node_map
191            .get(branch_start_id)
192            .expect("Branch start node not found");
193        let end_idx = *self
194            .node_map
195            .get(branch_end_id)
196            .expect("Branch end node not found");
197
198        let is_child = self.nodes[start_idx].children.iter().any(|&c| c == end_idx);
199        assert!(
200            self.nodes[end_idx].parent == Some(start_idx) || is_child,
201            "branch_end is not a direct child of branch_start"
202        );
203
204        let id = match node_type {
205            NodeType::Steiner => {
206                let s = format!("steiner_{}", self.next_steiner_id);
207                self.next_steiner_id += 1;
208                s
209            }
210            NodeType::Terminal => {
211                let s = format!("terminal_{}", self.next_terminal_id);
212                self.next_terminal_id += 1;
213                s
214            }
215            _ => panic!("Node type must be Steiner or Terminal"),
216        };
217        let new_idx = self.add_node(RoutingNode::new(&id, node_type, point));
218
219        // Rewire: start -> new -> end
220        self.nodes[start_idx].children.retain(|c| *c != end_idx);
221        self.nodes[end_idx].parent = None;
222
223        self.nodes[new_idx].parent = Some(start_idx);
224        self.nodes[start_idx].children.push(new_idx);
225
226        self.nodes[end_idx].parent = Some(new_idx);
227        self.nodes[new_idx].children.push(end_idx);
228
229        id
230    }
231
232    /// Find the nearest insertion point for a terminal, avoiding keepouts.
233    /// Returns `(parent_node_idx, nearest_node_idx)` where parent_node is `Some` when a Steiner
234    /// point needs to be inserted on the branch between parent and nearest.
235    fn _find_insertion_point(
236        &self,
237        point: Point<i32, i32>,
238        allowed_wirelength: i32,
239        keepouts: &Option<Vec<Point<Interval<i32>, Interval<i32>>>>,
240    ) -> (Option<usize>, usize) {
241        let mut nearest_node = self.source_idx;
242        let mut parent_node: Option<usize> = None;
243        let mut min_distance = self.worst_wirelength.max(1);
244        let mut valid_found = false;
245
246        // Stack-based DFS
247        let mut stack = vec![self.source_idx];
248
249        while let Some(node_idx) = stack.pop() {
250            let child_count = self.nodes[node_idx].children.len();
251            // Push children in order (reverse for DFS order)
252            for ci in (0..child_count).rev() {
253                let child_idx = self.nodes[node_idx].children[ci];
254                let possible_path = self.nodes[node_idx].pt.hull_with(&self.nodes[child_idx].pt);
255                let distance = possible_path.min_dist_with(&point) as i32;
256                let nearest_pt = possible_path.nearest_to(&point);
257
258                // Check keepouts
259                if let Some(ref kos) = *keepouts {
260                    let mut blocked = false;
261                    let path1 = nearest_pt.hull_with(&point);
262                    let path2 = nearest_pt.hull_with(&self.nodes[node_idx].pt);
263                    let path3 = nearest_pt.hull_with(&self.nodes[child_idx].pt);
264                    for ko in kos {
265                        if ko.contains(&nearest_pt)
266                            || ko.blocks(&path1)
267                            || ko.blocks(&path2)
268                            || ko.blocks(&path3)
269                        {
270                            blocked = true;
271                            break;
272                        }
273                    }
274                    if blocked {
275                        continue;
276                    }
277                }
278
279                let path_length = self.nodes[node_idx].path_length
280                    + self.nodes[node_idx].pt.min_dist_with(&nearest_pt) as i32
281                    + distance;
282
283                let mut update = false;
284                if path_length <= allowed_wirelength {
285                    if valid_found {
286                        if distance < min_distance {
287                            update = true;
288                        }
289                    } else {
290                        valid_found = true;
291                        update = true;
292                    }
293                } else if !valid_found
294                    && path_length <= self.worst_wirelength
295                    && distance < min_distance
296                {
297                    update = true;
298                }
299
300                if update {
301                    min_distance = distance;
302                    if nearest_pt == self.nodes[node_idx].pt {
303                        nearest_node = node_idx;
304                        parent_node = None;
305                    } else if nearest_pt == self.nodes[child_idx].pt {
306                        nearest_node = child_idx;
307                        parent_node = None;
308                    } else {
309                        nearest_node = child_idx;
310                        parent_node = Some(node_idx);
311                    }
312                }
313
314                stack.push(child_idx);
315            }
316        }
317
318        (parent_node, nearest_node)
319    }
320
321    fn _insert_terminal_impl(
322        &mut self,
323        point: Point<i32, i32>,
324        allowed_wirelength: i32,
325        keepouts: Option<Vec<Point<Interval<i32>, Interval<i32>>>>,
326    ) {
327        let terminal_id = format!("terminal_{}", self.next_terminal_id);
328        self.next_terminal_id += 1;
329        let terminal_idx = self.add_node(RoutingNode::new(&terminal_id, NodeType::Terminal, point));
330
331        let (parent_node, nearest_node) =
332            self._find_insertion_point(point, allowed_wirelength, &keepouts);
333
334        let nearest_idx = nearest_node;
335        match parent_node {
336            None => {
337                self.nodes[terminal_idx].parent = Some(nearest_idx);
338                self.nodes[nearest_idx].children.push(terminal_idx);
339                let dist = self.nodes[nearest_idx].pt.min_dist_with(&point) as i32;
340                self.nodes[terminal_idx].path_length = self.nodes[nearest_idx].path_length + dist;
341            }
342            Some(parent_idx) => {
343                let steiner_id = format!("steiner_{}", self.next_steiner_id);
344                self.next_steiner_id += 1;
345
346                let possible_path = self.nodes[parent_idx]
347                    .pt
348                    .hull_with(&self.nodes[nearest_idx].pt);
349                let nearest_pt = possible_path.nearest_to(&point);
350                let steiner_idx =
351                    self.add_node(RoutingNode::new(&steiner_id, NodeType::Steiner, nearest_pt));
352
353                // Rewire: parent -> nearest  becomes  parent -> steiner -> nearest
354                self.nodes[parent_idx]
355                    .children
356                    .retain(|c| *c != nearest_idx);
357                self.nodes[nearest_idx].parent = None;
358
359                self.nodes[steiner_idx].parent = Some(parent_idx);
360                self.nodes[parent_idx].children.push(steiner_idx);
361
362                let dist_ps = self.nodes[parent_idx].pt.min_dist_with(&nearest_pt) as i32;
363                self.nodes[steiner_idx].path_length = self.nodes[parent_idx].path_length + dist_ps;
364
365                self.nodes[nearest_idx].parent = Some(steiner_idx);
366                self.nodes[steiner_idx].children.push(nearest_idx);
367
368                self.nodes[terminal_idx].parent = Some(steiner_idx);
369                self.nodes[steiner_idx].children.push(terminal_idx);
370
371                let dist_st = nearest_pt.min_dist_with(&point) as i32;
372                self.nodes[terminal_idx].path_length =
373                    self.nodes[steiner_idx].path_length + dist_st;
374            }
375        }
376    }
377
378    pub fn insert_terminal_with_steiner(
379        &mut self,
380        point: Point<i32, i32>,
381        keepouts: Option<Vec<Point<Interval<i32>, Interval<i32>>>>,
382    ) {
383        self._insert_terminal_impl(point, i32::MAX, keepouts);
384    }
385
386    pub fn insert_terminal_with_constraints(
387        &mut self,
388        point: Point<i32, i32>,
389        allowed_wirelength: i32,
390        keepouts: Option<Vec<Point<Interval<i32>, Interval<i32>>>>,
391    ) {
392        self._insert_terminal_impl(point, allowed_wirelength, keepouts);
393    }
394
395    /// Calculates the total wirelength of the entire routing tree.
396    ///
397    /// $$L = \sum_{\text{node}} \text{Manhattan}(\text{node},\; \text{parent(node)})$$
398    pub fn calculate_total_wirelength(&self) -> i32 {
399        let mut total = 0;
400        for node in &self.nodes {
401            if let Some(parent_idx) = node.parent {
402                total += self.nodes[parent_idx].manhattan_distance(node);
403            }
404        }
405        total
406    }
407
408    /// Calculates the worst-case (maximum) source-to-terminal wirelength.
409    ///
410    /// $$W = \max_{\text{leaf}} \sum_{\text{path(source, leaf)}} \text{edge\_length}$$
411    pub fn calculate_worst_wirelength(&self) -> i32 {
412        fn traverse(tree: &GlobalRoutingTree, idx: usize) -> i32 {
413            let node = &tree.nodes[idx];
414            let mut worst = 0;
415            for &child in &node.children {
416                let child_len = node.manhattan_distance(&tree.nodes[child]);
417                let child_path = traverse(tree, child);
418                worst = worst.max(child_len + child_path);
419            }
420            worst
421        }
422        traverse(self, self.source_idx)
423    }
424
425    /// Finds the path from a node back to the source.
426    ///
427    /// Returns the nodes along the path in order from source to the target node.
428    pub fn find_path_to_source(&self, node_id: &str) -> Vec<&RoutingNode> {
429        let mut idx = *self.node_map.get(node_id).expect("Node not found");
430        let mut path = Vec::new();
431        loop {
432            path.push(&self.nodes[idx]);
433            match self.nodes[idx].parent {
434                Some(p) => idx = p,
435                None => break,
436            }
437        }
438        path.reverse();
439        path
440    }
441
442    /// Returns all terminal nodes in the routing tree.
443    pub fn get_all_terminals(&self) -> Vec<&RoutingNode> {
444        self.nodes
445            .iter()
446            .filter(|n| n.node_type == NodeType::Terminal)
447            .collect()
448    }
449
450    /// Returns all Steiner nodes in the routing tree.
451    pub fn get_all_steiner_nodes(&self) -> Vec<&RoutingNode> {
452        self.nodes
453            .iter()
454            .filter(|n| n.node_type == NodeType::Steiner)
455            .collect()
456    }
457
458    /// Returns a formatted string representation of the tree structure.
459    pub fn get_tree_structure(&self) -> String {
460        fn fmt_node(tree: &GlobalRoutingTree, idx: usize, level: usize) -> String {
461            let node = &tree.nodes[idx];
462            let mut s = format!(
463                "{}{}({}, {})",
464                "  ".repeat(level),
465                node.node_type,
466                node.id,
467                node.pt
468            );
469            s.push('\n');
470            for &child in &node.children {
471                s.push_str(&fmt_node(tree, child, level + 1));
472            }
473            s
474        }
475        fmt_node(self, self.source_idx, 0)
476    }
477
478    pub fn visualize_tree(&self) {
479        println!("Global Routing Tree Structure:");
480        println!("================================");
481        print!("{}", self.get_tree_structure());
482        println!("Total wirelength: {}", self.calculate_total_wirelength());
483        println!("Total nodes: {}", self.nodes.len());
484        println!("Terminals: {}", self.get_all_terminals().len());
485        println!("Steiner points: {}", self.get_all_steiner_nodes().len());
486    }
487
488    /// Removes redundant Steiner points that have only one child.
489    ///
490    /// After optimization, the remaining Steiner points have at least two
491    /// children and are topologically significant.
492    pub fn optimize_steiner_points(&mut self) {
493        let to_remove: Vec<usize> = self
494            .nodes
495            .iter()
496            .enumerate()
497            .filter(|(_, n)| {
498                n.node_type == NodeType::Steiner && n.children.len() == 1 && n.parent.is_some()
499            })
500            .map(|(i, _)| i)
501            .collect();
502
503        for &idx in to_remove.iter().rev() {
504            let parent = self.nodes[idx].parent;
505            let child = self.nodes[idx].children[0];
506            if let Some(p) = parent {
507                self.nodes[p].children.retain(|&c| c != idx);
508                self.nodes[p].children.push(child);
509            }
510            self.nodes[child].parent = parent;
511            self.node_map.remove(&self.nodes[idx].id);
512        }
513        self.nodes.retain(|n| self.node_map.contains_key(&n.id));
514        self.remap_indices();
515    }
516
517    fn remap_indices(&mut self) {
518        let old_indices: Vec<usize> = (0..self.nodes.len()).collect();
519        let mut new_map = HashMap::new();
520        let mut new_nodes = Vec::new();
521        for node in self.nodes.drain(..) {
522            let new_idx = new_nodes.len();
523            new_map.insert(node.id.clone(), new_idx);
524            new_nodes.push(node);
525        }
526        self.nodes = new_nodes;
527        for node in &mut self.nodes {
528            if let Some(p) = node.parent {
529                if let Some(&old_p) = old_indices.get(p) {
530                    node.parent = Some(old_p);
531                } else {
532                    node.parent = None;
533                }
534            }
535            node.children = node
536                .children
537                .iter()
538                .filter_map(|c| old_indices.get(*c).copied())
539                .collect();
540        }
541        self.node_map = new_map;
542    }
543
544    /// Generate an SVG visualization of the routing tree.
545    pub fn to_svg(
546        &self,
547        keepouts: Option<&Vec<Point<Interval<i32>, Interval<i32>>>>,
548        width: u32,
549        height: u32,
550        margin: u32,
551    ) -> String {
552        if self.nodes.is_empty() {
553            return "<svg></svg>".to_string();
554        }
555
556        let min_x = self.nodes.iter().map(|n| n.pt.xcoord).min().unwrap();
557        let max_x = self.nodes.iter().map(|n| n.pt.xcoord).max().unwrap();
558        let min_y = self.nodes.iter().map(|n| n.pt.ycoord).min().unwrap();
559        let max_y = self.nodes.iter().map(|n| n.pt.ycoord).max().unwrap();
560
561        let range_x = (max_x - min_x).max(1) as f64;
562        let range_y = (max_y - min_y).max(1) as f64;
563
564        let w = (width as f64) - 2.0 * (margin as f64);
565        let h = (height as f64) - 2.0 * (margin as f64);
566        let scale = (w / range_x).min(h / range_y);
567
568        let sx = |x: i32| margin as f64 + (x - min_x) as f64 * scale;
569        let sy = |y: i32| margin as f64 + (y - min_y) as f64 * scale;
570
571        let mut svg = String::new();
572        svg.push_str(&format!(
573            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">"#,
574            width, height
575        ));
576        svg.push_str(r#"<rect width="100%" height="100%" fill="white"/>"#);
577
578        // Arrowhead marker
579        svg.push_str(
580            r#"<defs><marker id="ah" markerWidth="10" markerHeight="7" refX="9" refY="3.5" orient="auto">"#,
581        );
582        svg.push_str(r#"<polygon points="0 0, 10 3.5, 0 7" fill="black"/>"#);
583        svg.push_str("</marker></defs>");
584
585        // Draw connections
586        fn draw_conn(
587            svg: &mut String,
588            tree: &GlobalRoutingTree,
589            idx: usize,
590            sx: &dyn Fn(i32) -> f64,
591            sy: &dyn Fn(i32) -> f64,
592        ) {
593            let node = &tree.nodes[idx];
594            for &child in &node.children {
595                let cnode = &tree.nodes[child];
596                let (x1, y1) = (sx(node.pt.xcoord), sy(node.pt.ycoord));
597                let (x2, y2) = (sx(cnode.pt.xcoord), sy(cnode.pt.ycoord));
598                svg.push_str(&format!(
599                    r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="black" stroke-width="2" marker-end="url(#ah)"/>"#,
600                    x1, y1, x2, y2
601                ));
602            }
603            for &child in &node.children {
604                draw_conn(svg, tree, child, sx, sy);
605            }
606        }
607        draw_conn(&mut svg, self, self.source_idx, &sx, &sy);
608
609        // Draw keepouts
610        if let Some(kos) = keepouts {
611            for ko in kos {
612                let x1 = sx(ko.xcoord.lb);
613                let y1 = sy(ko.ycoord.lb);
614                let x2 = sx(ko.xcoord.ub);
615                let y2 = sy(ko.ycoord.ub);
616                let rw = (x2 - x1).abs();
617                let rh = (y2 - y1).abs();
618                svg.push_str(&format!(
619                    r#"<rect x="{}" y="{}" width="{}" height="{}" fill="orange" stroke="black" stroke-width="1"/>"#,
620                    x1.min(x2), y1.min(y2), rw, rh
621                ));
622            }
623        }
624
625        // Draw nodes
626        for node in &self.nodes {
627            let (x_pos, y_pos) = (sx(node.pt.xcoord), sy(node.pt.ycoord));
628            let label = match node.node_type {
629                NodeType::Source => "S".to_string(),
630                NodeType::Steiner => {
631                    format!("S{}", node.id.strip_prefix("steiner_").unwrap_or("t"))
632                }
633                NodeType::Terminal => {
634                    format!("T{}", node.id.strip_prefix("terminal_").unwrap_or(""))
635                }
636            };
637            let (color, radius) = match node.node_type {
638                NodeType::Source => ("red", 8u32),
639                NodeType::Steiner => ("blue", 6u32),
640                NodeType::Terminal => ("green", 6u32),
641            };
642            svg.push_str(&format!(
643                r#"<circle cx="{}" cy="{}" r="{}" fill="{}" stroke="black" stroke-width="1"/>"#,
644                x_pos, y_pos, radius, color
645            ));
646            svg.push_str(&format!(
647                r#"<text x="{}" y="{}" font-family="Arial" font-size="10" fill="black">{}</text>"#,
648                x_pos + radius as f64 + 2.0,
649                y_pos + 4.0,
650                label
651            ));
652            svg.push_str(&format!(
653                r#"<text x="{}" y="{}" font-family="Arial" font-size="8" fill="gray" text-anchor="middle">({},{})</text>"#,
654                x_pos,
655                y_pos - radius as f64 - 5.0,
656                node.pt.xcoord,
657                node.pt.ycoord
658            ));
659        }
660
661        // Legend
662        let ly = 20u32;
663        svg.push_str(&format!(
664            r#"<text x="20" y="{}" font-family="Arial" font-size="12" font-weight="bold">Legend:</text>"#,
665            ly
666        ));
667        let items = [
668            ("Source", "red", 20, ly + 20),
669            ("Steiner", "blue", 20, ly + 40),
670            ("Terminal", "green", 20, ly + 60),
671        ];
672        for (text, color, lx, ly) in &items {
673            svg.push_str(&format!(
674                r#"<circle cx="{}" cy="{}" r="4" fill="{}" stroke="black"/>"#,
675                lx,
676                ly - 4,
677                color
678            ));
679            svg.push_str(&format!(
680                r#"<text x="{}" y="{}" font-family="Arial" font-size="10">{}</text>"#,
681                lx + 10,
682                ly,
683                text
684            ));
685        }
686
687        // Statistics
688        let sy2 = ly + 90;
689        svg.push_str(&format!(
690            r#"<text x="20" y="{}" font-family="Arial" font-size="10" font-weight="bold">Statistics:</text>"#,
691            sy2
692        ));
693        svg.push_str(&format!(
694            r#"<text x="20" y="{}" font-family="Arial" font-size="9">Total Nodes: {}</text>"#,
695            sy2 + 15,
696            self.nodes.len()
697        ));
698        svg.push_str(&format!(
699            r#"<text x="20" y="{}" font-family="Arial" font-size="9">Terminals: {}</text>"#,
700            sy2 + 30,
701            self.get_all_terminals().len()
702        ));
703        svg.push_str(&format!(
704            r#"<text x="20" y="{}" font-family="Arial" font-size="9">Steiner: {}</text>"#,
705            sy2 + 45,
706            self.get_all_steiner_nodes().len()
707        ));
708        svg.push_str(&format!(
709            r#"<text x="20" y="{}" font-family="Arial" font-size="9">Wirelength: {}</text>"#,
710            sy2 + 60,
711            self.calculate_total_wirelength()
712        ));
713
714        svg.push_str("</svg>");
715        svg
716    }
717
718    /// Save SVG to a file.
719    pub fn save_svg(
720        &self,
721        keepouts: Option<&Vec<Point<Interval<i32>, Interval<i32>>>>,
722        filename: &str,
723        width: u32,
724        height: u32,
725    ) {
726        let svg = self.to_svg(keepouts, width, height, 50);
727        std::fs::write(filename, svg).expect("Failed to write SVG file");
728        println!("Saved SVG to {}", filename);
729    }
730}
731
732/// High-level global router that constructs a routing tree from a source
733/// and a set of terminal points, with optional keepout avoidance.
734pub struct GlobalRouter {
735    terminal_positions: Vec<Point<i32, i32>>,
736    tree: GlobalRoutingTree,
737    worst_wirelength: i32,
738    keepouts: Option<Vec<Point<Interval<i32>, Interval<i32>>>>,
739}
740
741impl GlobalRouter {
742    pub fn new(
743        source_pos: Point<i32, i32>,
744        terminal_positions: Vec<Point<i32, i32>>,
745        keepout_regions: Option<Vec<Point<Interval<i32>, Interval<i32>>>>,
746    ) -> Self {
747        let mut sorted = terminal_positions.clone();
748        sorted.sort_by(|a, b| {
749            let da = source_pos.min_dist_with(a) as i32;
750            let db = source_pos.min_dist_with(b) as i32;
751            da.cmp(&db)
752        });
753
754        let worst = if sorted.is_empty() {
755            0
756        } else {
757            source_pos.min_dist_with(&sorted[sorted.len() - 1]) as i32
758        };
759
760        GlobalRouter {
761            terminal_positions: sorted,
762            tree: GlobalRoutingTree::new(source_pos),
763            worst_wirelength: worst,
764            keepouts: keepout_regions,
765        }
766    }
767
768    /// Routes terminals by directly connecting each terminal to the nearest
769    /// node in the existing tree (simple nearest-neighbor heuristic).
770    pub fn route_simple(&mut self) {
771        for &terminal in &self.terminal_positions {
772            self.tree.insert_terminal_node(terminal, None);
773        }
774    }
775
776    /// Routes terminals with Steiner point insertion to reduce total
777    /// wirelength while avoiding keepout regions.
778    pub fn route_with_steiners(&mut self) {
779        self.tree.worst_wirelength = self.worst_wirelength;
780        for &terminal in &self.terminal_positions {
781            self.tree
782                .insert_terminal_with_steiner(terminal, self.keepouts.clone());
783        }
784    }
785
786    /// Routes terminals with Steiner points and wirelength constraints.
787    /// The `multiplier` scales the worst-case wirelength to determine the
788    /// allowed path length for each terminal.
789    pub fn route_with_constraints(&mut self, multiplier: f64) {
790        let allowed = (self.worst_wirelength as f64 * multiplier).round() as i32;
791        self.tree.worst_wirelength = self.worst_wirelength;
792        for &terminal in &self.terminal_positions {
793            self.tree
794                .insert_terminal_with_constraints(terminal, allowed, self.keepouts.clone());
795        }
796    }
797
798    /// Returns a reference to the constructed routing tree.
799    pub fn get_tree(&self) -> &GlobalRoutingTree {
800        &self.tree
801    }
802}
803
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rectangular keepout from two x and two y coordinates given in any order.
    fn make_keepout(x1: i32, x2: i32, y1: i32, y2: i32) -> Point<Interval<i32>, Interval<i32>> {
        Point::new(
            Interval::new(x1.min(x2), x1.max(x2)),
            Interval::new(y1.min(y2), y1.max(y2)),
        )
    }

    #[test]
    fn test_route_simple() {
        let src = Point::new(0, 0);
        let terminals = vec![Point::new(1, 1), Point::new(2, 2)];
        let mut router = GlobalRouter::new(src, terminals, None);
        router.route_simple();
        assert_eq!(router.get_tree().calculate_total_wirelength(), 4);
    }

    #[test]
    fn test_route_with_steiners() {
        let src = Point::new(0, 0);
        let terminals = vec![Point::new(1, 1), Point::new(2, 2)];
        let mut router = GlobalRouter::new(src, terminals, None);
        router.route_with_steiners();
        assert_eq!(router.get_tree().calculate_total_wirelength(), 4);
    }

    #[test]
    fn test_route_with_constraints() {
        let src = Point::new(0, 0);
        let terminals = vec![Point::new(1, 1), Point::new(2, 2)];
        let mut router = GlobalRouter::new(src, terminals, None);
        router.route_with_constraints(2.0);
        assert_eq!(router.get_tree().calculate_total_wirelength(), 4);
    }

    #[test]
    fn test_route_three_sinks_simple() {
        let src = Point::new(0, 0);
        let terminals = vec![Point::new(10, 0), Point::new(5, 10)];
        let mut router = GlobalRouter::new(src, terminals, None);
        router.route_simple();
        let wl = router.get_tree().calculate_total_wirelength();
        assert_eq!(wl, 25);
    }

    #[test]
    fn test_route_with_keepout() {
        let src = Point::new(0, 0);
        let terminals = vec![Point::new(10, 0)];
        // The keepout straddles the straight path along y = 0.
        let keepout = make_keepout(4, 6, -1, 1);
        let mut router = GlobalRouter::new(src, terminals, Some(vec![keepout]));
        router.route_with_steiners();
        let wl = router.get_tree().calculate_total_wirelength();
        // Any rectilinear connection between (0,0) and (10,0) is at least the
        // Manhattan distance (10) long, whether or not it detours.
        assert!(
            wl >= 10,
            "wirelength {} is below the Manhattan lower bound of 10",
            wl
        );
    }

    #[test]
    fn test_insert_steiner_and_terminal() {
        let mut tree = GlobalRoutingTree::new(Point::new(0, 0));
        let s1 = tree.insert_steiner_node(Point::new(1, 1), None);
        let t1 = tree.insert_terminal_node(Point::new(2, 2), Some(&s1));
        assert_eq!(tree.calculate_total_wirelength(), 4);
        assert_eq!(tree.get_all_terminals().len(), 1);
        assert_eq!(tree.get_all_steiner_nodes().len(), 1);
        let path = tree.find_path_to_source(&t1);
        assert_eq!(path.len(), 3);
        assert_eq!(path[0].id, "source");
        assert_eq!(path[1].id, s1);
        assert_eq!(path[2].id, t1);
    }

    #[test]
    fn test_insert_node_on_branch() {
        let mut tree = GlobalRoutingTree::new(Point::new(0, 0));
        let s1 = tree.insert_steiner_node(Point::new(1, 1), None);
        let t1 = tree.insert_terminal_node(Point::new(2, 2), Some(&s1));
        let new_id = tree.insert_node_on_branch(NodeType::Steiner, Point::new(1, 2), &s1, &t1);
        // Second Steiner node created in this tree.
        assert_eq!(new_id, "steiner_2");
        // The inserted node lengthens the source-to-terminal path by one hop.
        let path = tree.find_path_to_source(&t1);
        assert_eq!(path.len(), 4);
    }

    #[test]
    fn test_optimize_steiner_points() {
        let mut tree = GlobalRoutingTree::new(Point::new(0, 0));
        let s1 = tree.insert_steiner_node(Point::new(1, 1), None);
        let _t1 = tree.insert_terminal_node(Point::new(2, 2), Some(&s1));
        tree.optimize_steiner_points();
        assert_eq!(tree.get_all_steiner_nodes().len(), 0);
    }

    #[test]
    fn test_calculate_worst_wirelength() {
        let mut tree = GlobalRoutingTree::new(Point::new(0, 0));
        let s1 = tree.insert_steiner_node(Point::new(1, 1), None);
        let _t1 = tree.insert_terminal_node(Point::new(2, 2), Some(&s1));
        assert_eq!(tree.calculate_worst_wirelength(), 4);
    }

    #[test]
    fn test_to_svg_contains_elements() {
        let mut tree = GlobalRoutingTree::new(Point::new(0, 0));
        let s1 = tree.insert_steiner_node(Point::new(1, 1), None);
        // Use the id actually returned rather than assuming the naming scheme.
        let _t1 = tree.insert_terminal_node(Point::new(2, 2), Some(&s1));
        let svg = tree.to_svg(None, 200, 200, 50);
        assert!(svg.contains("<svg"), "SVG output is missing the opening <svg> tag");
        assert!(svg.contains("</svg>"), "SVG output is missing the closing </svg> tag");
        assert!(svg.contains("Source"), "SVG output is missing the Source label");
        assert!(svg.contains("Wirelength"), "SVG output is missing the Wirelength label");
    }

    #[test]
    fn test_terminal_sorting_by_distance() {
        let src = Point::new(0, 0);
        let terminals = vec![Point::new(10, 0), Point::new(1, 0), Point::new(5, 0)];
        let router = GlobalRouter::new(src, terminals, None);
        // Terminals are expected in increasing distance from the source.
        assert_eq!(router.terminal_positions[0], Point::new(1, 0));
        assert_eq!(router.terminal_positions[1], Point::new(5, 0));
        assert_eq!(router.terminal_positions[2], Point::new(10, 0));
    }
}