Skip to main content

bifrostlink/
route.rs

1use std::{
2	collections::{hash_map::Entry, HashMap, HashSet},
3	hash::Hash,
4};
5
6use serde::{Deserialize, Serialize};
7use tokio::sync::mpsc::UnboundedSender as Sender;
8use tracing::warn;
9
10use crate::{event::RootEvent, AddressT};
11
/// The hop used to reach a destination: either another peer's address or a
/// direct connection.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Via<Address> {
	/// The destination is reached through the peer with this address.
	Address(Address),
	/// The destination is reached over a direct connection.
	Direct,
}

impl<Address: Clone> Via<Address> {
	/// Returns a clone of the intermediate address, or `None` for
	/// [`Via::Direct`].
	fn as_address(&self) -> Option<Address> {
		match self {
			Self::Direct => None,
			Self::Address(inner) => Some(inner.clone()),
		}
	}
}
28
29#[derive(PartialOrd, Ord, PartialEq, Eq, Clone, Serialize, Deserialize, Copy, Debug)]
30pub struct Rtt(pub u32);
31
32#[derive(PartialEq, Clone, Debug)]
33pub struct MinRtt<Address> {
34	pub via: Via<Address>,
35	pub rtt: Rtt,
36	pub second_best: Option<Rtt>,
37}
38
39#[derive(Debug)]
40pub struct MinRttUpdated<Address> {
41	pub for_address: Address,
42	pub rtt: MinRtt<Address>,
43	pub first_changed: bool,
44	pub second_changed: bool,
45}
46
47#[derive(Debug)]
48pub struct ViaListSeconded<Address> {
49	pub for_connection: Address,
50	pub initial_via: Via<Address>,
51	pub added_via: Via<Address>,
52	pub rtt: Rtt,
53}
54
55#[derive(Debug)]
56pub struct ViaListUnseconded<Address> {
57	pub for_connection: Address,
58	pub only_via: Via<Address>,
59}
60
61#[derive(Debug)]
62pub struct ConnectionAdded<Address> {
63	pub to: Address,
64	pub via: Via<Address>,
65	pub rtt: Rtt,
66}
67#[derive(Debug)]
68pub struct ConnectionRemoved<Address> {
69	pub to: Address,
70	pub via: Via<Address>,
71}
72
73struct AddressData<Address> {
74	// address: Address,
75	via: HashMap<Via<Address>, Rtt>,
76	min_rtt: MinRtt<Address>,
77}
78impl<Address> AddressData<Address>
79where
80	Address: Clone + PartialEq,
81{
82	fn update_min_rtt(&mut self, for_address: Address, sender: &mut Sender<RootEvent<Address>>) {
83		let (via, rtt) = self
84			.via
85			.iter()
86			.min_by_key(|(_, rtt)| **rtt)
87			.expect("updated address with no routes");
88		let second_best = self
89			.via
90			.iter()
91			.filter(|(second, _)| second != &via)
92			.map(|(_, rtt)| rtt)
93			.min();
94
95		let old = &self.min_rtt;
96		let new = MinRtt {
97			via: via.clone(),
98			rtt: *rtt,
99			second_best: second_best.cloned(),
100		};
101
102		if old == &new {
103			return;
104		}
105
106		let only_first_updated = &old.rtt != rtt;
107		let only_second_updated = old.second_best.as_ref() != second_best;
108		let min_via_updated = &old.via != via;
109
110		if sender
111			.send(
112				MinRttUpdated {
113					for_address,
114					rtt: new.clone(),
115					first_changed: only_first_updated || min_via_updated,
116					second_changed: only_second_updated || min_via_updated,
117				}
118				.into(),
119			)
120			.is_err()
121		{
122			warn!("no handlers for min rtt update")
123		}
124
125		self.min_rtt = new;
126	}
127}
128
129#[derive(derivative::Derivative)]
130#[derivative(Default(bound = ""))]
131struct InverseRouteSet<Address> {
132	vias: HashMap<Via<Address>, HashSet<Address>>,
133}
134impl<Address> InverseRouteSet<Address>
135where
136	Address: AddressT,
137{
138	fn inc(&mut self, via: Via<Address>, to: Address) {
139		let routes = self.vias.entry(via).or_default();
140		assert!(routes.insert(to), "inverse imbalance (double inc)");
141	}
142	fn dec(&mut self, via: Via<Address>, to: Address) {
143		let routes = self
144			.vias
145			.get_mut(&via)
146			.expect("inverse imbalance (unknown dec)");
147		assert!(routes.remove(&to), "inverse imbalance (double dec route)");
148		if routes.is_empty() {
149			self.vias.remove(&via);
150		}
151	}
152	fn forwarded(&self, via: Via<Address>) -> Option<impl Iterator<Item = Address> + '_> {
153		let routes = self.vias.get(&via)?;
154		Some(routes.iter().cloned())
155	}
156}
157
158pub struct RouteSet<Address> {
159	routes: HashMap<Address, AddressData<Address>>,
160	inverse: InverseRouteSet<Address>,
161	event: Sender<RootEvent<Address>>,
162}
163
164impl<Address> RouteSet<Address>
165where
166	Address: AddressT,
167{
168	pub fn new(tx: Sender<RootEvent<Address>>) -> Self {
169		Self {
170			routes: Default::default(),
171			inverse: Default::default(),
172			event: tx,
173		}
174	}
175	pub fn inc(&mut self, address: Address, via: Via<Address>, rtt: Rtt) {
176		match self.routes.entry(address.clone()) {
177			Entry::Occupied(mut v) => {
178				let data = v.get_mut();
179				let seconded_initial = (data.via.len() == 1).then(|| {
180					let (via, rtt) = data.via.iter().next().unwrap();
181					(via.clone(), *rtt)
182				});
183				{
184					let Entry::Vacant(via) = data.via.entry(via.clone()) else {
185						warn!("added duplicate connection: {address:?} via {via:?}");
186						return;
187					};
188					via.insert(rtt);
189				}
190				if let Some((initial_via, initial_rtt)) = seconded_initial {
191					if self
192						.event
193						.send(
194							ViaListSeconded {
195								for_connection: address.clone(),
196								initial_via: initial_via.clone(),
197								added_via: via.clone(),
198								rtt: rtt.min(initial_rtt),
199							}
200							.into(),
201						)
202						.is_err()
203					{
204						warn!("no listener for ViaListSeconded")
205					}
206				}
207				let via = v.key().clone();
208				v.get_mut().update_min_rtt(via, &mut self.event);
209			}
210			Entry::Vacant(v) => {
211				v.insert(AddressData {
212					// address: address.clone(),
213					via: [(via.clone(), rtt)].into_iter().collect(),
214					min_rtt: MinRtt {
215						via: via.clone(),
216						rtt,
217						second_best: None,
218					},
219				});
220				if self
221					.event
222					.send(
223						ConnectionAdded {
224							to: address.clone(),
225							via: via.clone(),
226							rtt,
227						}
228						.into(),
229					)
230					.is_err()
231				{
232					warn!("no listener for ConnectionAdded")
233				}
234			}
235		}
236		self.inverse.inc(via, address)
237	}
238	pub fn dec(&mut self, address: Address, via: Via<Address>) {
239		let Some(data) = self.routes.get_mut(&address) else {
240			warn!("removed unknown connection: {address:?} via {via:?} (there is no routes to the specified address)");
241			return;
242		};
243		if data.via.remove(&via).is_none() {
244			warn!("removed unknown connection: {address:?} via {via:?}");
245			return;
246		}
247		if data.via.is_empty() {
248			self.routes.remove(&address);
249			if self
250				.event
251				.send(
252					ConnectionRemoved {
253						to: address.clone(),
254						via: via.clone(),
255					}
256					.into(),
257				)
258				.is_err()
259			{
260				warn!("no listener for ConnectionRemoved");
261			}
262		} else {
263			if data.via.len() == 1 {
264				let only_via = data.via.keys().next().expect("len == 1").clone();
265				if self
266					.event
267					.send(
268						ViaListUnseconded {
269							for_connection: address.clone(),
270							only_via,
271						}
272						.into(),
273					)
274					.is_err()
275				{
276					warn!("no listener for ConnectionRemoved");
277				}
278			}
279			data.update_min_rtt(address.clone(), &mut self.event);
280		}
281		self.inverse.dec(via, address)
282	}
283	pub fn update(&mut self, address: Address, via: Via<Address>, rtt: Rtt) {
284		let Some(data) = self.routes.get_mut(&address) else {
285			warn!("updated rtt for unknown connection");
286			return;
287		};
288		let Some(viartt) = data.via.get_mut(&via) else {
289			warn!("updated rtt for unknown connection");
290			return;
291		};
292		*viartt = rtt;
293		data.update_min_rtt(address, &mut self.event)
294	}
295	pub fn has(&self, address: Address) -> bool {
296		self.routes.contains_key(&address)
297	}
298	pub fn list(&self) -> impl Iterator<Item = (Address, MinRtt<Address>)> + '_ {
299		self.routes
300			.iter()
301			.map(|(a, d)| (a.clone(), d.min_rtt.clone()))
302	}
303
304	pub fn may_be_forwarder_for(&self, forwarder: Via<Address>, sender: Address) -> bool {
305		if forwarder.as_address().as_ref() == Some(&sender) {
306			return true;
307		}
308		let Some(connections) = self.routes.get(&sender) else {
309			// No connection
310			return false;
311		};
312		connections.via.contains_key(&forwarder)
313	}
314	pub fn forwarder_for(
315		&self,
316		address: Address,
317		blacklist: &HashSet<Via<Address>>,
318	) -> Option<Via<Address>> {
319		let connections = self.routes.get(&address)?;
320		// Has direct connection
321		if connections.via.contains_key(&Via::Direct) {
322			return Some(Via::Direct);
323		}
324
325		// Best possible
326		connections
327			.via
328			.iter()
329			.filter(|(via, _)| !blacklist.contains(via))
330			.min_by_key(|(_, rtt)| **rtt)
331			.map(|(via, _)| via.clone())
332	}
333
334	pub fn on_add_direct_connection(&mut self, address: Address, rtt: Rtt) {
335		self.inc(address, Via::Direct, rtt);
336	}
337	pub fn on_remove_direct_connection(&mut self, address: Address) {
338		self.dec(address, Via::Direct);
339	}
340}