calyx_opt/analysis/
graph_coloring.rs1use calyx_utils::{Idx, WeightGraph};
2use itertools::Itertools;
3use petgraph::algo;
4use std::{
5    collections::{BTreeMap, HashMap},
6    hash::Hash,
7};
8
9pub struct GraphColoring<T> {
11    graph: WeightGraph<T>,
12    color_freq_map: HashMap<Idx, i64>,
16}
17
18impl<T, C> From<C> for GraphColoring<T>
19where
20    T: Hash + Eq + Ord,
21    C: Iterator<Item = T>,
22{
23    fn from(nodes: C) -> Self {
24        let graph = WeightGraph::from(nodes);
25        GraphColoring {
26            graph,
27            color_freq_map: HashMap::new(),
28        }
29    }
30}
31
32impl<'a, T> GraphColoring<T>
33where
34    T: 'a + Eq + Hash + Clone + Ord,
35{
36    #[inline(always)]
38    pub fn insert_conflict(&mut self, a: &T, b: &T) {
39        self.graph.add_edge(a, b);
40    }
41
42    pub fn insert_conflicts<C>(&mut self, items: C)
44    where
45        C: Iterator<Item = &'a T> + Clone,
46    {
47        self.graph.add_all_edges(items)
48    }
49
50    pub fn has_nodes(&self) -> bool {
51        self.graph.graph.node_count() > 0
52    }
53
54    fn increase_freq(&mut self, idx: Idx) {
56        self.color_freq_map
57            .entry(idx)
58            .and_modify(|v| *v += 1)
59            .or_insert(1);
60    }
61
62    pub fn get_share_freqs(&mut self) -> HashMap<i64, i64> {
64        let mut pdf: HashMap<i64, i64> = HashMap::new();
65        for value in self.color_freq_map.values() {
67            pdf.entry(*value).and_modify(|v| *v += 1).or_insert(1);
71        }
72        pdf
73    }
74
75    pub fn color_greedy(
80        &mut self,
81        bound: Option<i64>,
82        keep_self_color: bool,
83    ) -> HashMap<T, T> {
84        let mut all_colors: BTreeMap<Idx, i64> = BTreeMap::new();
85        let mut coloring: HashMap<Idx, Idx> = HashMap::new();
86        let always_share = bound.is_none();
87        let bound_if_exists = if always_share { 0 } else { bound.unwrap() };
89
90        let sccs = algo::tarjan_scc(&self.graph.graph);
92        for scc in sccs.into_iter().sorted_by(|a, b| b.len().cmp(&a.len())) {
94            let is_complete = scc.iter().all(|&idx| {
96                self.graph.graph.neighbors(idx).count() == scc.len() - 1
97            });
98            if is_complete {
101                let mut available_colors: Vec<_> =
102                    all_colors.keys().cloned().collect_vec();
103
104                for nidx in scc.into_iter().sorted() {
106                    if !available_colors.is_empty() {
107                        let c = available_colors.remove(0);
108                        coloring.insert(nidx, c);
109                        self.increase_freq(c);
110                        if let Some(num_used) = all_colors.get_mut(&c) {
111                            *num_used += 1;
112                            if !always_share && *num_used == bound_if_exists {
113                                all_colors.remove(&c);
114                            }
115                        }
116                    } else {
117                        all_colors.insert(nidx, 1);
118                        coloring.insert(nidx, nidx);
119                        self.increase_freq(nidx);
120                        if !always_share && bound_if_exists == 1 {
121                            all_colors.remove(&nidx);
122                        }
123                    }
124                }
125            } else {
126                for nidx in scc.into_iter().sorted() {
127                    let mut available_colors = all_colors.clone();
128                    for item in self.graph.graph.neighbors(nidx) {
130                        if coloring.contains_key(&item) {
132                            available_colors.remove(&coloring[&item]);
134                        }
135                    }
136
137                    let color = available_colors.iter().next();
138                    match color {
139                        Some((c, _)) => {
140                            coloring.insert(nidx, *c);
141                            self.increase_freq(*c);
142                            if let Some(num_used) = all_colors.get_mut(c) {
143                                *num_used += 1;
144                                if !always_share && *num_used == bound_if_exists
145                                {
146                                    all_colors.remove(c);
147                                }
148                            }
149                        }
150                        None => {
151                            all_colors.insert(nidx, 1);
153                            coloring.insert(nidx, nidx);
154                            self.increase_freq(nidx);
155                            if !always_share && bound_if_exists == 1 {
156                                all_colors.remove(&nidx);
157                            }
158                        }
159                    };
160                }
161            }
162        }
163
164        let rev_map = self.graph.reverse_index();
165        coloring
166            .into_iter()
167            .map(|(n1, n2)| (rev_map[&n1].clone(), rev_map[&n2].clone()))
168            .filter(|(a, b)| (a != b) || keep_self_color)
169            .collect()
170    }
171
172    pub fn welsh_powell_coloring(&self) -> HashMap<T, T> {
173        let mut coloring: HashMap<T, T> = HashMap::new();
174
175        let mut degree_ordering: Vec<&T> = self
176            .graph
177            .nodes()
178            .sorted()
179            .sorted_by(|a, b| self.graph.degree(b).cmp(&self.graph.degree(a)))
180            .collect();
181
182        let rev_map = self.graph.reverse_index();
183        while !degree_ordering.is_empty() {
184            let head = degree_ordering.remove(0);
185            if !coloring.contains_key(head) {
187                coloring.insert(head.clone(), head.clone());
188                for &node in °ree_ordering {
189                    if coloring.contains_key(node) {
190                        continue;
191                    }
192                    if !self
193                        .graph
194                        .graph
195                        .neighbors(self.graph.index_map[node])
196                        .any(|x| coloring.get(&rev_map[&x]) == Some(head))
197                    {
198                        coloring.insert(node.clone(), head.clone());
199                    }
200                }
201            }
202        }
203
204        coloring
205    }
206}
207
208impl<T: Eq + Hash + ToString + Clone + Ord> ToString for GraphColoring<T> {
209    fn to_string(&self) -> String {
210        self.graph.to_string()
211    }
212}