1use std::{
2 collections::{hash_map::Entry, HashMap, HashSet},
3 hash::Hash,
4};
5
6use serde::{Deserialize, Serialize};
7use tokio::sync::mpsc::UnboundedSender as Sender;
8use tracing::warn;
9
10use crate::{event::RootEvent, AddressT};
11
/// How a destination is reached: through an intermediate address, or directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Via<Address> {
    /// The route goes through the given address.
    Address(Address),
    /// The route has no intermediate address.
    Direct,
}

impl<Address> Via<Address>
where
    Address: Clone,
{
    /// Returns a clone of the intermediate address, or `None` for [`Via::Direct`].
    fn as_address(&self) -> Option<Address> {
        if let Via::Address(address) = self {
            Some(address.clone())
        } else {
            None
        }
    }
}
28
29#[derive(PartialOrd, Ord, PartialEq, Eq, Clone, Serialize, Deserialize, Copy, Debug)]
30pub struct Rtt(pub u32);
31
32#[derive(PartialEq, Clone, Debug)]
33pub struct MinRtt<Address> {
34 pub via: Via<Address>,
35 pub rtt: Rtt,
36 pub second_best: Option<Rtt>,
37}
38
39#[derive(Debug)]
40pub struct MinRttUpdated<Address> {
41 pub for_address: Address,
42 pub rtt: MinRtt<Address>,
43 pub first_changed: bool,
44 pub second_changed: bool,
45}
46
47#[derive(Debug)]
48pub struct ViaListSeconded<Address> {
49 pub for_connection: Address,
50 pub initial_via: Via<Address>,
51 pub added_via: Via<Address>,
52 pub rtt: Rtt,
53}
54
55#[derive(Debug)]
56pub struct ViaListUnseconded<Address> {
57 pub for_connection: Address,
58 pub only_via: Via<Address>,
59}
60
61#[derive(Debug)]
62pub struct ConnectionAdded<Address> {
63 pub to: Address,
64 pub via: Via<Address>,
65 pub rtt: Rtt,
66}
67#[derive(Debug)]
68pub struct ConnectionRemoved<Address> {
69 pub to: Address,
70 pub via: Via<Address>,
71}
72
73struct AddressData<Address> {
74 via: HashMap<Via<Address>, Rtt>,
76 min_rtt: MinRtt<Address>,
77}
78impl<Address> AddressData<Address>
79where
80 Address: Clone + PartialEq,
81{
82 fn update_min_rtt(&mut self, for_address: Address, sender: &mut Sender<RootEvent<Address>>) {
83 let (via, rtt) = self
84 .via
85 .iter()
86 .min_by_key(|(_, rtt)| **rtt)
87 .expect("updated address with no routes");
88 let second_best = self
89 .via
90 .iter()
91 .filter(|(second, _)| second != &via)
92 .map(|(_, rtt)| rtt)
93 .min();
94
95 let old = &self.min_rtt;
96 let new = MinRtt {
97 via: via.clone(),
98 rtt: *rtt,
99 second_best: second_best.cloned(),
100 };
101
102 if old == &new {
103 return;
104 }
105
106 let only_first_updated = &old.rtt != rtt;
107 let only_second_updated = old.second_best.as_ref() != second_best;
108 let min_via_updated = &old.via != via;
109
110 if sender
111 .send(
112 MinRttUpdated {
113 for_address,
114 rtt: new.clone(),
115 first_changed: only_first_updated || min_via_updated,
116 second_changed: only_second_updated || min_via_updated,
117 }
118 .into(),
119 )
120 .is_err()
121 {
122 warn!("no handlers for min rtt update")
123 }
124
125 self.min_rtt = new;
126 }
127}
128
129#[derive(derivative::Derivative)]
130#[derivative(Default(bound = ""))]
131struct InverseRouteSet<Address> {
132 vias: HashMap<Via<Address>, HashSet<Address>>,
133}
134impl<Address> InverseRouteSet<Address>
135where
136 Address: AddressT,
137{
138 fn inc(&mut self, via: Via<Address>, to: Address) {
139 let routes = self.vias.entry(via).or_default();
140 assert!(routes.insert(to), "inverse imbalance (double inc)");
141 }
142 fn dec(&mut self, via: Via<Address>, to: Address) {
143 let routes = self
144 .vias
145 .get_mut(&via)
146 .expect("inverse imbalance (unknown dec)");
147 assert!(routes.remove(&to), "inverse imbalance (double dec route)");
148 if routes.is_empty() {
149 self.vias.remove(&via);
150 }
151 }
152 fn forwarded(&self, via: Via<Address>) -> Option<impl Iterator<Item = Address> + '_> {
153 let routes = self.vias.get(&via)?;
154 Some(routes.iter().cloned())
155 }
156}
157
158pub struct RouteSet<Address> {
159 routes: HashMap<Address, AddressData<Address>>,
160 inverse: InverseRouteSet<Address>,
161 event: Sender<RootEvent<Address>>,
162}
163
164impl<Address> RouteSet<Address>
165where
166 Address: AddressT,
167{
168 pub fn new(tx: Sender<RootEvent<Address>>) -> Self {
169 Self {
170 routes: Default::default(),
171 inverse: Default::default(),
172 event: tx,
173 }
174 }
175 pub fn inc(&mut self, address: Address, via: Via<Address>, rtt: Rtt) {
176 match self.routes.entry(address.clone()) {
177 Entry::Occupied(mut v) => {
178 let data = v.get_mut();
179 let seconded_initial = (data.via.len() == 1).then(|| {
180 let (via, rtt) = data.via.iter().next().unwrap();
181 (via.clone(), *rtt)
182 });
183 {
184 let Entry::Vacant(via) = data.via.entry(via.clone()) else {
185 warn!("added duplicate connection: {address:?} via {via:?}");
186 return;
187 };
188 via.insert(rtt);
189 }
190 if let Some((initial_via, initial_rtt)) = seconded_initial {
191 if self
192 .event
193 .send(
194 ViaListSeconded {
195 for_connection: address.clone(),
196 initial_via: initial_via.clone(),
197 added_via: via.clone(),
198 rtt: rtt.min(initial_rtt),
199 }
200 .into(),
201 )
202 .is_err()
203 {
204 warn!("no listener for ViaListSeconded")
205 }
206 }
207 let via = v.key().clone();
208 v.get_mut().update_min_rtt(via, &mut self.event);
209 }
210 Entry::Vacant(v) => {
211 v.insert(AddressData {
212 via: [(via.clone(), rtt)].into_iter().collect(),
214 min_rtt: MinRtt {
215 via: via.clone(),
216 rtt,
217 second_best: None,
218 },
219 });
220 if self
221 .event
222 .send(
223 ConnectionAdded {
224 to: address.clone(),
225 via: via.clone(),
226 rtt,
227 }
228 .into(),
229 )
230 .is_err()
231 {
232 warn!("no listener for ConnectionAdded")
233 }
234 }
235 }
236 self.inverse.inc(via, address)
237 }
238 pub fn dec(&mut self, address: Address, via: Via<Address>) {
239 let Some(data) = self.routes.get_mut(&address) else {
240 warn!("removed unknown connection: {address:?} via {via:?} (there is no routes to the specified address)");
241 return;
242 };
243 if data.via.remove(&via).is_none() {
244 warn!("removed unknown connection: {address:?} via {via:?}");
245 return;
246 }
247 if data.via.is_empty() {
248 self.routes.remove(&address);
249 if self
250 .event
251 .send(
252 ConnectionRemoved {
253 to: address.clone(),
254 via: via.clone(),
255 }
256 .into(),
257 )
258 .is_err()
259 {
260 warn!("no listener for ConnectionRemoved");
261 }
262 } else {
263 if data.via.len() == 1 {
264 let only_via = data.via.keys().next().expect("len == 1").clone();
265 if self
266 .event
267 .send(
268 ViaListUnseconded {
269 for_connection: address.clone(),
270 only_via,
271 }
272 .into(),
273 )
274 .is_err()
275 {
276 warn!("no listener for ConnectionRemoved");
277 }
278 }
279 data.update_min_rtt(address.clone(), &mut self.event);
280 }
281 self.inverse.dec(via, address)
282 }
283 pub fn update(&mut self, address: Address, via: Via<Address>, rtt: Rtt) {
284 let Some(data) = self.routes.get_mut(&address) else {
285 warn!("updated rtt for unknown connection");
286 return;
287 };
288 let Some(viartt) = data.via.get_mut(&via) else {
289 warn!("updated rtt for unknown connection");
290 return;
291 };
292 *viartt = rtt;
293 data.update_min_rtt(address, &mut self.event)
294 }
295 pub fn has(&self, address: Address) -> bool {
296 self.routes.contains_key(&address)
297 }
298 pub fn list(&self) -> impl Iterator<Item = (Address, MinRtt<Address>)> + '_ {
299 self.routes
300 .iter()
301 .map(|(a, d)| (a.clone(), d.min_rtt.clone()))
302 }
303
304 pub fn may_be_forwarder_for(&self, forwarder: Via<Address>, sender: Address) -> bool {
305 if forwarder.as_address().as_ref() == Some(&sender) {
306 return true;
307 }
308 let Some(connections) = self.routes.get(&sender) else {
309 return false;
311 };
312 connections.via.contains_key(&forwarder)
313 }
314 pub fn forwarder_for(
315 &self,
316 address: Address,
317 blacklist: &HashSet<Via<Address>>,
318 ) -> Option<Via<Address>> {
319 let connections = self.routes.get(&address)?;
320 if connections.via.contains_key(&Via::Direct) {
322 return Some(Via::Direct);
323 }
324
325 connections
327 .via
328 .iter()
329 .filter(|(via, _)| !blacklist.contains(via))
330 .min_by_key(|(_, rtt)| **rtt)
331 .map(|(via, _)| via.clone())
332 }
333
334 pub fn on_add_direct_connection(&mut self, address: Address, rtt: Rtt) {
335 self.inc(address, Via::Direct, rtt);
336 }
337 pub fn on_remove_direct_connection(&mut self, address: Address) {
338 self.dec(address, Via::Direct);
339 }
340}