1use std::{
2 collections::HashMap,
3 fmt::Debug,
4 net::{IpAddr, SocketAddr},
5 time::Instant,
6};
7
8use bitfold_core::{
9 config::Config,
10 interceptor::{Interceptor, NoOpInterceptor},
11 packet_pool::PacketAllocator,
12 transport::Socket as TransportSocket,
13};
14use crossbeam_channel::{unbounded, Receiver, Sender};
15use tracing::error;
16
17use super::{
18 event_types::Action,
19 session::{Session, SessionEventAddress},
20};
21
/// A destination that accepts events of type `E`, one at a time.
trait EventSink<E> {
    /// Hands `item` over to the sink.
    fn send(&mut self, item: E);
}
30
31#[derive(Debug)]
33struct ChannelSink<E>(Sender<E>);
34
35impl<E> ChannelSink<E> {
36 fn new(sender: Sender<E>) -> Self {
37 Self(sender)
38 }
39}
40
41impl<E> EventSink<E> for ChannelSink<E> {
42 fn send(&mut self, event: E) {
43 self.0.send(event).expect("Receiver must exist");
44 }
45}
46
47struct SocketEventSenderAndConfig<TSocket: TransportSocket, ReceiveEvent: Debug> {
48 config: Config,
49 socket: TSocket,
50 event_sender: ChannelSink<ReceiveEvent>,
51 pending_sends: Vec<(SocketAddr, Vec<u8>)>,
52 pending_events: Vec<ReceiveEvent>,
53 interceptor: Box<dyn Interceptor>,
54 send_pool: PacketAllocator,
56}
57
58impl<TSocket: TransportSocket, ReceiveEvent: Debug> std::fmt::Debug
59 for SocketEventSenderAndConfig<TSocket, ReceiveEvent>
60{
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("SocketEventSenderAndConfig")
63 .field("config", &self.config)
64 .field("socket", &"<socket>")
65 .field("event_sender", &self.event_sender)
66 .field("pending_sends", &self.pending_sends)
67 .field("pending_events", &self.pending_events)
68 .field("interceptor", &"<interceptor>")
69 .finish()
70 }
71}
72
73impl<TSocket: TransportSocket, ReceiveEvent: Debug>
74 SocketEventSenderAndConfig<TSocket, ReceiveEvent>
75{
76 fn new(
77 config: Config,
78 socket: TSocket,
79 event_sender: Sender<ReceiveEvent>,
80 interceptor: Box<dyn Interceptor>,
81 ) -> Self {
82 let pool = PacketAllocator::new(config.max_packet_size, 256);
84 Self {
85 config,
86 socket,
87 event_sender: ChannelSink::new(event_sender),
88 pending_sends: Vec::new(),
89 pending_events: Vec::new(),
90 interceptor,
91 send_pool: pool,
92 }
93 }
94
95 fn handle_actions(&mut self, address: &SocketAddr, actions: Vec<Action<ReceiveEvent>>) {
96 for action in actions {
97 match action {
98 Action::Send(bytes) => self.pending_sends.push((*address, bytes)),
99 Action::Emit(ev) => self.pending_events.push(ev),
100 }
101 }
102 }
103
104 fn flush(&mut self) {
105 for (addr, mut payload) in self.pending_sends.drain(..) {
106 if !self.interceptor.on_send(&addr, &mut payload) {
108 self.send_pool.deallocate(payload);
111 continue;
112 }
113
114 if let Err(err) = self.socket.send_packet(&addr, &payload) {
115 error!("Error occured sending a packet (to {}): {}", addr, err)
116 }
117 self.send_pool.deallocate(payload);
119 }
120 for event in self.pending_events.drain(..) {
121 self.event_sender.send(event);
122 }
123 }
124}
125
126#[derive(Debug)]
128pub struct SessionManager<TSocket: TransportSocket, TSession: Session> {
129 sessions: HashMap<SocketAddr, TSession>,
130 receive_buffer: Vec<u8>,
131 user_event_receiver: Receiver<TSession::SendEvent>,
132 messenger: SocketEventSenderAndConfig<TSocket, TSession::ReceiveEvent>,
133 event_receiver: Receiver<TSession::ReceiveEvent>,
134 user_event_sender: Sender<TSession::SendEvent>,
135 max_unestablished_sessions: u16,
136 duplicate_peer_count: HashMap<IpAddr, usize>,
138 max_duplicate_peers: u16,
140}
141
142impl<TSocket: TransportSocket, TSession: Session> SessionManager<TSocket, TSession> {
143 pub fn new(socket: TSocket, config: Config) -> Self {
145 Self::new_with_interceptor(socket, config, None)
146 }
147
148 pub fn new_with_interceptor(
150 socket: TSocket,
151 config: Config,
152 interceptor: Option<Box<dyn Interceptor>>,
153 ) -> Self {
154 let (event_sender, event_receiver) = unbounded();
155 let (user_event_sender, user_event_receiver) = unbounded();
156 let max_unestablished_sessions = config.max_unestablished_connections;
157 let max_duplicate_peers = config.max_duplicate_peers;
158
159 let interceptor = interceptor.unwrap_or_else(|| Box::new(NoOpInterceptor));
160
161 SessionManager {
162 receive_buffer: vec![0; config.receive_buffer_max_size],
163 sessions: Default::default(),
164 user_event_receiver,
165 messenger: SocketEventSenderAndConfig::new(config, socket, event_sender, interceptor),
166 user_event_sender,
167 event_receiver,
168 max_unestablished_sessions,
169 duplicate_peer_count: HashMap::new(),
170 max_duplicate_peers,
171 }
172 }
173
174 pub fn manual_poll(&mut self, time: Instant) {
176 let mut unestablished_sessions = self.unestablished_session_count();
177
178 loop {
179 match self.messenger.socket.receive_packet(self.receive_buffer.as_mut()) {
180 Ok((payload, address)) => {
181 let payload_len = payload.len();
182
183 let should_process = {
185 let buf_slice = &mut self.receive_buffer[..payload_len];
186 self.messenger.interceptor.on_receive(&address, buf_slice)
187 };
188
189 if !should_process {
190 continue;
192 }
193
194 let payload = &self.receive_buffer[..payload_len];
196
197 if let Some(session) = self.sessions.get_mut(&address) {
198 let was_est = session.is_established();
199 let actions = session.process_packet(payload, time);
200 self.messenger.handle_actions(&address, actions);
201 if !was_est && session.is_established() {
202 unestablished_sessions -= 1;
203 }
204 } else {
205 let mut session =
206 TSession::create_session(&self.messenger.config, address, time);
207 let actions = session.process_packet(payload, time);
208 self.messenger.handle_actions(&address, actions);
209 if unestablished_sessions < self.max_unestablished_sessions as usize
211 && self.can_accept_duplicate(&address)
212 {
213 self.sessions.insert(address, session);
214 self.increment_duplicate_count(&address);
215 unestablished_sessions += 1;
216 }
217 }
218 }
219 Err(e) => {
220 if e.kind() != std::io::ErrorKind::WouldBlock {
221 error!("Encountered an error receiving data: {:?}", e);
222 }
223 break;
224 }
225 }
226 if self.messenger.socket.is_blocking_mode() {
227 break;
228 }
229 }
230
231 while let Ok(event) = self.user_event_receiver.try_recv() {
232 let addr = event.address();
233
234 let is_new_session = !self.sessions.contains_key(&addr);
236 let can_create = !is_new_session || self.can_accept_duplicate(&addr);
237
238 if is_new_session && !can_create {
240 continue;
241 }
242
243 use std::collections::hash_map::Entry;
245 match self.sessions.entry(addr) {
246 Entry::Occupied(mut entry) => {
247 let session = entry.get_mut();
248 let was_est = session.is_established();
249 let actions = session.process_event(event, time);
250 self.messenger.handle_actions(&addr, actions);
251 if !was_est && session.is_established() {
252 unestablished_sessions -= 1;
253 }
254 }
255 Entry::Vacant(entry) => {
256 let mut session = TSession::create_session(&self.messenger.config, addr, time);
257 let actions = session.process_event(event, time);
258 entry.insert(session);
259 self.messenger.handle_actions(&addr, actions);
260 self.increment_duplicate_count(&addr);
261 }
262 }
263 }
264
265 for (addr, session) in self.sessions.iter_mut() {
266 let actions = session.update(time);
267 self.messenger.handle_actions(addr, actions);
268 }
269
270 let mut to_drop = Vec::new();
272 for (addr, session) in self.sessions.iter_mut() {
273 let (drop, actions) = session.should_drop(time);
274 self.messenger.handle_actions(addr, actions);
275 if drop {
276 to_drop.push(*addr);
277 }
278 }
279
280 for addr in to_drop {
282 self.sessions.remove(&addr);
283 self.decrement_duplicate_count(&addr);
284 }
285
286 self.messenger.flush();
287 }
288
289 pub fn event_sender(&self) -> &Sender<TSession::SendEvent> {
291 &self.user_event_sender
292 }
293
294 pub fn event_receiver(&self) -> &Receiver<TSession::ReceiveEvent> {
296 &self.event_receiver
297 }
298
299 pub fn socket(&self) -> &TSocket {
301 &self.messenger.socket
302 }
303
304 fn unestablished_session_count(&self) -> usize {
305 self.sessions.iter().filter(|s| !s.1.is_established()).count()
306 }
307
308 #[allow(dead_code)]
309 pub fn socket_mut(&mut self) -> &mut TSocket {
311 &mut self.messenger.socket
312 }
313
314 pub fn sessions_count(&self) -> usize {
316 self.sessions.len()
317 }
318
319 pub fn session_mut(&mut self, addr: &SocketAddr) -> Option<&mut TSession> {
321 self.sessions.get_mut(addr)
322 }
323
324 pub fn established_sessions(&self) -> impl Iterator<Item = &SocketAddr> {
326 self.sessions.iter().filter(|(_, s)| s.is_established()).map(|(addr, _)| addr)
327 }
328
329 pub fn established_sessions_count(&self) -> usize {
331 self.sessions.iter().filter(|(_, s)| s.is_established()).count()
332 }
333
334 fn increment_duplicate_count(&mut self, addr: &SocketAddr) {
336 let ip = addr.ip();
337 *self.duplicate_peer_count.entry(ip).or_insert(0) += 1;
338 }
339
340 fn decrement_duplicate_count(&mut self, addr: &SocketAddr) {
343 let ip = addr.ip();
344 if let Some(count) = self.duplicate_peer_count.get_mut(&ip) {
345 *count -= 1;
346 if *count == 0 {
347 self.duplicate_peer_count.remove(&ip);
348 }
349 }
350 }
351
352 fn can_accept_duplicate(&self, addr: &SocketAddr) -> bool {
355 if self.max_duplicate_peers == 0 {
357 return true;
358 }
359
360 let ip = addr.ip();
361 let current_count = self.duplicate_peer_count.get(&ip).copied().unwrap_or(0);
362 current_count < self.max_duplicate_peers as usize
363 }
364
365 pub fn duplicate_peer_count(&self, addr: &SocketAddr) -> usize {
367 self.duplicate_peer_count.get(&addr.ip()).copied().unwrap_or(0)
368 }
369}