1pub mod entry;
36pub mod iter;
37
38use crate::{
39 entry::{BorrowedEntry, OwnedEntry},
40 iter::{IntoIter, Iter},
41};
42use itertools::Itertools;
43use priority_queue::PriorityQueue;
44use std::{collections::HashMap, hash::Hash};
45
46#[derive(Debug, Clone)]
49pub struct PriorityMatrix<R, C, W>
50where
51 R: Clone + Eq + Hash,
52 C: Clone + Eq + Hash,
53 W: Clone + Ord,
54{
55 entries: PriorityQueue<(R, C), W>,
56 rows: HashMap<R, PriorityQueue<C, W>>,
57 cols: HashMap<C, PriorityQueue<R, W>>,
58}
59
60impl<R, C, W> PriorityMatrix<R, C, W>
61where
62 R: Clone + Eq + Hash,
63 C: Clone + Eq + Hash,
64 W: Clone + Ord,
65{
66 pub fn new(&self) -> Self {
67 Self::default()
68 }
69
70 pub fn insert(&mut self, row: R, col: C, weight: W) -> Option<W> {
71 let prev_weight = self
72 .entries
73 .push((row.clone(), col.clone()), weight.clone());
74 self.rows
75 .entry(row.clone())
76 .or_insert_with(PriorityQueue::default)
77 .push(col.clone(), weight.clone());
78 self.cols
79 .entry(col)
80 .or_insert_with(PriorityQueue::default)
81 .push(row, weight);
82 prev_weight
83 }
84
85 pub fn peek(&self) -> Option<BorrowedEntry<'_, R, C, W>> {
86 let ((row, col), weight) = self.entries.peek()?;
87 Some(BorrowedEntry {
88 row,
89 column: col,
90 weight,
91 })
92 }
93
94 pub fn peek_from_row<'a>(&'a self, row: &'a R) -> Option<BorrowedEntry<'_, R, C, W>> {
95 let (col, _) = self.rows.get(row)?.peek().unwrap();
96 let key = (row.clone(), col.clone());
97 let (_, weight) = self.entries.get(&key).unwrap();
98 Some(BorrowedEntry {
99 row,
100 column: col,
101 weight,
102 })
103 }
104
105 pub fn peek_from_column<'a>(&'a self, col: &'a C) -> Option<BorrowedEntry<'a, R, C, W>> {
106 let (row, _) = self.cols.get(col)?.peek().unwrap();
107 let key = (row.clone(), col.clone());
108 let (_, weight) = self.entries.get(&key).unwrap();
109 Some(BorrowedEntry {
110 row,
111 column: col,
112 weight,
113 })
114 }
115
116 pub fn pop(&mut self) -> Option<OwnedEntry<R, C, W>> {
117 let ((row, col), weight) = self.entries.pop()?;
118 self.rows.get_mut(&row).unwrap().remove(&col);
119 self.rows.get_mut(&row).unwrap().remove(&col);
120 Some(OwnedEntry {
121 row,
122 column: col,
123 weight,
124 })
125 }
126
127 pub fn pop_from_row(&mut self, row: &R) -> Option<OwnedEntry<R, C, W>> {
128 let (col, weight) = self.rows.get_mut(row)?.pop().unwrap();
129 let key = (row.clone(), col.clone());
130 self.entries.remove(&key);
131 self.cols.get_mut(&col).unwrap().remove(row);
132 Some(OwnedEntry {
133 row: row.clone(),
134 column: col,
135 weight,
136 })
137 }
138
139 pub fn pop_from_column(&mut self, col: &C) -> Option<OwnedEntry<R, C, W>> {
140 let (row, weight) = self.cols.get_mut(col)?.pop().unwrap();
141 let key = (row.clone(), col.clone());
142 self.entries.remove(&key);
143 self.rows.get_mut(&row).unwrap().remove(col);
144 Some(OwnedEntry {
145 row,
146 column: col.clone(),
147 weight,
148 })
149 }
150
151 pub fn remove(&mut self, row: &R, col: &C) -> bool {
152 let ok = self.entries.remove(&(row.clone(), col.clone())).is_some();
153 if !ok {
154 return false;
155 }
156
157 self.rows.get_mut(row).unwrap().remove(col);
158 self.cols.get_mut(col).unwrap().remove(row);
159 true
160 }
161
162 pub fn remove_row(&mut self, row: &R) {
163 self.rows
164 .remove(row)
165 .into_iter()
166 .flatten()
167 .map(|(curr_col, _)| (row.clone(), curr_col))
168 .for_each(|(row, col)| {
169 if let Some(queue) = self.cols.get_mut(&col) {
170 queue.remove(&row);
171 }
172 self.entries.remove(&(row, col));
173 });
174 }
175
176 pub fn remove_column(&mut self, col: &C) {
177 self.cols
178 .remove(col)
179 .into_iter()
180 .flatten()
181 .map(|(curr_row, _)| (curr_row, col.clone()))
182 .for_each(|(row, col)| {
183 if let Some(queue) = self.rows.get_mut(&row) {
184 queue.remove(&col);
185 }
186 self.entries.remove(&(row, col));
187 });
188 }
189
190 pub fn remove_row_and_column(&mut self, row: &R, col: &C) {
191 let row_keys = self
192 .rows
193 .remove(row)
194 .into_iter()
195 .flatten()
196 .map(|(curr_col, _)| (row.clone(), curr_col));
197 let col_keys = self
198 .cols
199 .remove(col)
200 .into_iter()
201 .flatten()
202 .map(|(curr_row, _)| (curr_row, col.clone()));
203 let all_keys = row_keys.chain(col_keys);
204
205 all_keys.for_each(|key| {
206 self.entries.remove(&key);
207 });
208 }
209
210 pub fn iter(&self) -> Iter<'_, R, C, W> {
211 Iter {
212 iter: self.entries.iter(),
213 }
214 }
215
216 pub fn len(&self) -> usize {
218 self.entries.len()
219 }
220
221 pub fn is_empty(&self) -> bool {
222 self.entries.is_empty()
223 }
224
225 pub fn row_keys(&self) -> impl Iterator<Item = &R> {
226 self.rows.keys()
227 }
228
229 pub fn column_keys(&self) -> impl Iterator<Item = &C> {
230 self.cols.keys()
231 }
232}
233
234impl<R, C, W> Default for PriorityMatrix<R, C, W>
235where
236 R: Clone + Eq + Hash,
237 C: Clone + Eq + Hash,
238 W: Clone + Ord,
239{
240 fn default() -> Self {
241 Self {
242 entries: PriorityQueue::new(),
243 rows: HashMap::new(),
244 cols: HashMap::new(),
245 }
246 }
247}
248
249impl<R, C, W> FromIterator<(R, C, W)> for PriorityMatrix<R, C, W>
250where
251 R: Clone + Eq + Hash,
252 C: Clone + Eq + Hash,
253 W: Clone + Ord,
254{
255 fn from_iter<T>(iter: T) -> Self
256 where
257 T: IntoIterator<Item = (R, C, W)>,
258 {
259 let entries: PriorityQueue<(R, C), W> = iter
260 .into_iter()
261 .map(|(row, col, val)| ((row, col), val))
262 .collect();
263 let rows: HashMap<R, PriorityQueue<C, W>> = entries
264 .iter()
265 .map(|((row, col), iou)| (row.clone(), (col.clone(), iou.clone())))
266 .into_grouping_map()
267 .collect();
268 let cols: HashMap<C, PriorityQueue<R, W>> = entries
269 .iter()
270 .map(|((row, col), iou)| (col.clone(), (row.clone(), iou.clone())))
271 .into_grouping_map()
272 .collect();
273
274 PriorityMatrix {
275 entries,
276 rows,
277 cols,
278 }
279 }
280}
281
282impl<R, C, W> IntoIterator for PriorityMatrix<R, C, W>
283where
284 R: Clone + Eq + Hash,
285 C: Clone + Eq + Hash,
286 W: Clone + Ord,
287{
288 type Item = (R, C, W);
289 type IntoIter = IntoIter<R, C, W>;
290
291 fn into_iter(self) -> Self::IntoIter {
292 IntoIter {
293 iter: self.entries.into_iter(),
294 }
295 }
296}