// File: atm0s_sdn_network/controller_plane/neighbours.rs

use std::{
    collections::{HashMap, VecDeque},
    net::{IpAddr, SocketAddr},
    sync::Arc,
};

use atm0s_sdn_identity::{ConnId, NodeAddr, NodeId, Protocol};
use sans_io_runtime::TaskSwitcherChild;

use crate::{
    base::{self, Authorization, ConnectionCtx, HandshakeBuilder, NeighboursControl, NeighboursControlCmds, SecureContext},
    data_plane::NetPair,
};

use self::connection::{ConnectionEvent, NeighbourConnection};

mod connection;

19pub enum Input {
20 ConnectTo(NodeAddr),
21 DisconnectFrom(NodeId),
22 Control(NetPair, NeighboursControl),
23}
24
25pub enum Output {
26 Control(NetPair, NeighboursControl),
27 Event(base::ConnectionEvent),
28 OnResourceEmpty,
29}
30
31pub struct NeighboursManager {
32 node_id: NodeId,
33 bind_addrs: Vec<SocketAddr>,
34 connections: HashMap<NetPair, NeighbourConnection>,
35 neighbours: HashMap<ConnId, ConnectionCtx>,
36 queue: VecDeque<Output>,
37 shutdown: bool,
38 authorization: Arc<dyn Authorization>,
39 handshake_builder: Arc<dyn HandshakeBuilder>,
40 random: Box<dyn rand::RngCore>,
41}
42
43impl NeighboursManager {
44 pub fn new(node_id: NodeId, bind_addrs: Vec<SocketAddr>, authorization: Arc<dyn Authorization>, handshake_builder: Arc<dyn HandshakeBuilder>, random: Box<dyn rand::RngCore>) -> Self {
45 Self {
46 node_id,
47 bind_addrs,
48 connections: HashMap::new(),
49 neighbours: HashMap::new(),
50 queue: VecDeque::new(),
51 shutdown: false,
52 authorization,
53 handshake_builder,
54 random,
55 }
56 }
57
58 pub fn conn(&self, conn: ConnId) -> Option<&ConnectionCtx> {
59 self.neighbours.get(&conn)
60 }
61
62 pub fn on_tick(&mut self, now_ms: u64, _tick_count: u64) {
63 for conn in self.connections.values_mut() {
64 conn.on_tick(now_ms);
65 }
66 }
67
68 pub fn on_input(&mut self, now_ms: u64, input: Input) {
69 match input {
70 Input::ConnectTo(addr) => {
71 if addr.node_id() == self.node_id {
72 log::warn!("[Neighbours] Attempt to connect to self");
73 return;
74 }
75 let dest_node = addr.node_id();
76 let dests = get_node_addr_dests(addr);
77 for local in &self.bind_addrs {
78 for remote in &dests {
79 if local.is_ipv4() != remote.is_ipv4() {
80 continue;
81 }
82
83 let pair = NetPair::new(*local, *remote);
84 if self.connections.contains_key(&pair) {
85 continue;
86 }
87 log::info!("[Neighbours] Sending connect request from {local} to {remote}, dest_node {dest_node}");
88 let session_id = self.random.next_u64();
89 let conn = NeighbourConnection::new_outgoing(self.handshake_builder.clone(), self.node_id, dest_node, session_id, pair, now_ms);
90 self.queue.push_back(Output::Event(base::ConnectionEvent::Connecting(conn.ctx())));
91 self.connections.insert(pair, conn);
92 }
93 }
94 }
95 Input::DisconnectFrom(node) => {
96 for conn in self.connections.values_mut() {
97 if conn.dest_node() == node {
98 conn.disconnect(now_ms);
99 }
100 }
101 }
102 Input::Control(addr, control) => {
103 let cmd: NeighboursControlCmds = match control.validate(now_ms, &*self.authorization) {
104 Ok(cmd) => cmd,
105 Err(_) => {
106 log::warn!("[Neighbours] Invalid control from {:?}", addr);
107 return;
108 }
109 };
110
111 log::debug!("[NeighboursManager] received Control(addr: {:?}, cmd: {:?})", addr, cmd);
112 if let Some(conn) = self.connections.get_mut(&addr) {
113 conn.on_input(now_ms, control.from, cmd);
114 } else {
115 match cmd {
116 NeighboursControlCmds::ConnectRequest { session, .. } => {
117 let mut conn = NeighbourConnection::new_incoming(self.handshake_builder.clone(), self.node_id, control.from, session, addr, now_ms);
118 conn.on_input(now_ms, control.from, cmd);
119 self.queue.push_back(Output::Event(base::ConnectionEvent::Connecting(conn.ctx())));
120 self.connections.insert(addr, conn);
121 }
122 _ => {
123 log::warn!("[Neighbours] Neighbour connection not found for control {:?}", control);
124 }
125 }
126 }
127 }
128 }
129 }
130
131 pub fn on_shutdown(&mut self, now_ms: u64) {
132 if self.shutdown {
133 return;
134 }
135 self.shutdown = true;
136 for conn in self.connections.values_mut() {
137 conn.disconnect(now_ms);
138 }
139 }
140}
141
142impl TaskSwitcherChild<Output> for NeighboursManager {
143 type Time = u64;
144
145 fn empty_event(&self) -> Output {
146 Output::OnResourceEmpty
147 }
148
149 fn is_empty(&self) -> bool {
150 self.shutdown && self.connections.is_empty() && self.queue.is_empty()
151 }
152
153 fn pop_output(&mut self, _now: u64) -> Option<Output> {
154 if let Some(output) = self.queue.pop_front() {
155 return Some(output);
156 }
157
158 let mut to_remove = Vec::new();
159 for (remote, conn) in self.connections.iter_mut() {
160 while let Some(output) = conn.pop_output() {
161 match output {
162 connection::Output::Event(event) => {
163 let event = match event {
164 ConnectionEvent::Connected(encryptor, decryptor) => {
165 let ctx = conn.ctx();
166 self.neighbours.insert(ctx.conn, ctx.clone());
167 Some(base::ConnectionEvent::Connected(ctx, SecureContext { encryptor, decryptor }))
168 }
169 ConnectionEvent::ConnectError(err) => {
170 to_remove.push(*remote);
171 Some(base::ConnectionEvent::ConnectError(conn.ctx(), err))
172 }
173 ConnectionEvent::Stats(stats) => {
174 let ctx = conn.ctx();
175 Some(base::ConnectionEvent::Stats(ctx, stats))
176 }
177 ConnectionEvent::Disconnected => {
178 let ctx = conn.ctx();
179 self.neighbours.remove(&ctx.conn);
180 to_remove.push(*remote);
181 Some(base::ConnectionEvent::Disconnected(ctx))
182 }
183 };
184 if let Some(event) = event {
185 self.queue.push_back(Output::Event(event));
186 }
187 }
188 connection::Output::Net(now_ms, remote, cmd) => {
189 log::debug!("[NeighboursManager] pop_output Net(remote: {:?}, cmd: {:?})", remote, cmd);
190 self.queue.push_back(Output::Control(remote, NeighboursControl::build(now_ms, self.node_id, cmd, &*self.authorization)));
191 }
192 }
193 }
194 }
195
196 for remote in to_remove {
197 self.connections.remove(&remote);
198 }
199
200 self.queue.pop_front()
201 }
202}
203
204fn get_node_addr_dests(addr: NodeAddr) -> Vec<SocketAddr> {
205 let mut dests = Vec::new();
206 log::info!("Connect to: addr {}", addr);
207 let mut dest_ip = None;
208 for part in addr.multiaddr().iter() {
209 match part {
210 Protocol::Ip4(i) => {
211 dest_ip = Some(IpAddr::V4(i));
212 }
213 Protocol::Ip6(i) => {
214 dest_ip = Some(IpAddr::V6(i));
215 }
216 Protocol::Udp(port) => {
217 if let Some(ip) = dest_ip {
218 dests.push(SocketAddr::new(ip, port));
219 }
220 }
221 _ => {}
222 }
223 }
224 dests
225}