// bitfold_host/session_manager.rs

1use std::{
2    collections::HashMap,
3    fmt::Debug,
4    net::{IpAddr, SocketAddr},
5    time::Instant,
6};
7
8use bitfold_core::{
9    config::Config,
10    interceptor::{Interceptor, NoOpInterceptor},
11    packet_pool::PacketAllocator,
12    transport::Socket as TransportSocket,
13};
14use crossbeam_channel::{unbounded, Receiver, Sender};
15use tracing::error;
16
17use super::{
18    event_types::Action,
19    session::{Session, SessionEventAddress},
20};
21
22// ============================================================================
23// Event Sink (Internal)
24// ============================================================================
25
26/// Minimal event sink abstraction to decouple from a concrete channel.
27trait EventSink<E> {
28    fn send(&mut self, event: E);
29}
30
31/// Channel-backed event sink using crossbeam `Sender`.
32#[derive(Debug)]
33struct ChannelSink<E>(Sender<E>);
34
35impl<E> ChannelSink<E> {
36    fn new(sender: Sender<E>) -> Self {
37        Self(sender)
38    }
39}
40
41impl<E> EventSink<E> for ChannelSink<E> {
42    fn send(&mut self, event: E) {
43        self.0.send(event).expect("Receiver must exist");
44    }
45}
46
/// Socket, configuration, interceptor and outbound queues that session actions are
/// funneled into.
///
/// Actions are only queued by `handle_actions`; nothing is sent or emitted until `flush`.
struct SocketEventSenderAndConfig<TSocket: TransportSocket, ReceiveEvent: Debug> {
    /// Configuration; also handed to `TSession::create_session` by `SessionManager`.
    config: Config,
    /// Transport used to receive and send datagrams.
    socket: TSocket,
    /// Destination for events emitted by sessions.
    event_sender: ChannelSink<ReceiveEvent>,
    /// Outgoing `(destination, payload)` datagrams, drained by `flush`.
    pending_sends: Vec<(SocketAddr, Vec<u8>)>,
    /// Emitted events waiting to be forwarded, drained by `flush`.
    pending_events: Vec<ReceiveEvent>,
    /// Hook run on outgoing datagrams in `flush` (and on incoming ones by `SessionManager`).
    interceptor: Box<dyn Interceptor>,
    /// Pool to recycle send buffers and reduce allocations on hot paths
    send_pool: PacketAllocator,
}
57
58impl<TSocket: TransportSocket, ReceiveEvent: Debug> std::fmt::Debug
59    for SocketEventSenderAndConfig<TSocket, ReceiveEvent>
60{
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("SocketEventSenderAndConfig")
63            .field("config", &self.config)
64            .field("socket", &"<socket>")
65            .field("event_sender", &self.event_sender)
66            .field("pending_sends", &self.pending_sends)
67            .field("pending_events", &self.pending_events)
68            .field("interceptor", &"<interceptor>")
69            .finish()
70    }
71}
72
73impl<TSocket: TransportSocket, ReceiveEvent: Debug>
74    SocketEventSenderAndConfig<TSocket, ReceiveEvent>
75{
76    fn new(
77        config: Config,
78        socket: TSocket,
79        event_sender: Sender<ReceiveEvent>,
80        interceptor: Box<dyn Interceptor>,
81    ) -> Self {
82        // Pre-size pool buffers to typical max packet size; keep a modest pool
83        let pool = PacketAllocator::new(config.max_packet_size, 256);
84        Self {
85            config,
86            socket,
87            event_sender: ChannelSink::new(event_sender),
88            pending_sends: Vec::new(),
89            pending_events: Vec::new(),
90            interceptor,
91            send_pool: pool,
92        }
93    }
94
95    fn handle_actions(&mut self, address: &SocketAddr, actions: Vec<Action<ReceiveEvent>>) {
96        for action in actions {
97            match action {
98                Action::Send(bytes) => self.pending_sends.push((*address, bytes)),
99                Action::Emit(ev) => self.pending_events.push(ev),
100            }
101        }
102    }
103
104    fn flush(&mut self) {
105        for (addr, mut payload) in self.pending_sends.drain(..) {
106            // Call interceptor before sending
107            if !self.interceptor.on_send(&addr, &mut payload) {
108                // Interceptor dropped the packet
109                // Return buffer to pool for reuse
110                self.send_pool.deallocate(payload);
111                continue;
112            }
113
114            if let Err(err) = self.socket.send_packet(&addr, &payload) {
115                error!("Error occured sending a packet (to {}): {}", addr, err)
116            }
117            // Return the buffer to the pool for reuse
118            self.send_pool.deallocate(payload);
119        }
120        for event in self.pending_events.drain(..) {
121            self.event_sender.send(event);
122        }
123    }
124}
125
/// Session manager over a datagram socket and generic `Session` engine.
///
/// Keeps one `TSession` per remote address, routes received datagrams and user events
/// to them, and forwards the datagrams and events they produce (see `manual_poll`).
#[derive(Debug)]
pub struct SessionManager<TSocket: TransportSocket, TSession: Session> {
    /// Live sessions keyed by remote socket address.
    sessions: HashMap<SocketAddr, TSession>,
    /// Buffer handed to the socket for every receive; sized from
    /// `config.receive_buffer_max_size` at construction.
    receive_buffer: Vec<u8>,
    /// Receiving end of the channel whose sender is exposed by `event_sender()`.
    user_event_receiver: Receiver<TSession::SendEvent>,
    /// Socket, config, interceptor and queued outgoing datagrams/events.
    messenger: SocketEventSenderAndConfig<TSocket, TSession::ReceiveEvent>,
    /// Receiving end for events emitted by sessions; exposed by `event_receiver()`.
    event_receiver: Receiver<TSession::ReceiveEvent>,
    /// Sender handed to users (via `event_sender()`) to submit events to sessions.
    user_event_sender: Sender<TSession::SendEvent>,
    /// Maximum number of unestablished sessions that received packets may create.
    max_unestablished_sessions: u16,
    /// Tracks the number of connections per IP address for duplicate peer management
    duplicate_peer_count: HashMap<IpAddr, usize>,
    /// Maximum number of duplicate peers allowed (0 = unlimited)
    max_duplicate_peers: u16,
}
141
142impl<TSocket: TransportSocket, TSession: Session> SessionManager<TSocket, TSession> {
143    /// Creates a new session manager.
144    pub fn new(socket: TSocket, config: Config) -> Self {
145        Self::new_with_interceptor(socket, config, None)
146    }
147
148    /// Creates a new session manager with a custom interceptor.
149    pub fn new_with_interceptor(
150        socket: TSocket,
151        config: Config,
152        interceptor: Option<Box<dyn Interceptor>>,
153    ) -> Self {
154        let (event_sender, event_receiver) = unbounded();
155        let (user_event_sender, user_event_receiver) = unbounded();
156        let max_unestablished_sessions = config.max_unestablished_connections;
157        let max_duplicate_peers = config.max_duplicate_peers;
158
159        let interceptor = interceptor.unwrap_or_else(|| Box::new(NoOpInterceptor));
160
161        SessionManager {
162            receive_buffer: vec![0; config.receive_buffer_max_size],
163            sessions: Default::default(),
164            user_event_receiver,
165            messenger: SocketEventSenderAndConfig::new(config, socket, event_sender, interceptor),
166            user_event_sender,
167            event_receiver,
168            max_unestablished_sessions,
169            duplicate_peer_count: HashMap::new(),
170            max_duplicate_peers,
171        }
172    }
173
174    /// Polls for network I/O and processes all sessions.
175    pub fn manual_poll(&mut self, time: Instant) {
176        let mut unestablished_sessions = self.unestablished_session_count();
177
178        loop {
179            match self.messenger.socket.receive_packet(self.receive_buffer.as_mut()) {
180                Ok((payload, address)) => {
181                    let payload_len = payload.len();
182
183                    // Call interceptor on received data
184                    let should_process = {
185                        let buf_slice = &mut self.receive_buffer[..payload_len];
186                        self.messenger.interceptor.on_receive(&address, buf_slice)
187                    };
188
189                    if !should_process {
190                        // Interceptor dropped the packet
191                        continue;
192                    }
193
194                    // Re-get payload reference after interceptor potentially modified it
195                    let payload = &self.receive_buffer[..payload_len];
196
197                    if let Some(session) = self.sessions.get_mut(&address) {
198                        let was_est = session.is_established();
199                        let actions = session.process_packet(payload, time);
200                        self.messenger.handle_actions(&address, actions);
201                        if !was_est && session.is_established() {
202                            unestablished_sessions -= 1;
203                        }
204                    } else {
205                        let mut session =
206                            TSession::create_session(&self.messenger.config, address, time);
207                        let actions = session.process_packet(payload, time);
208                        self.messenger.handle_actions(&address, actions);
209                        // Check both unestablished limit and duplicate peer limit
210                        if unestablished_sessions < self.max_unestablished_sessions as usize
211                            && self.can_accept_duplicate(&address)
212                        {
213                            self.sessions.insert(address, session);
214                            self.increment_duplicate_count(&address);
215                            unestablished_sessions += 1;
216                        }
217                    }
218                }
219                Err(e) => {
220                    if e.kind() != std::io::ErrorKind::WouldBlock {
221                        error!("Encountered an error receiving data: {:?}", e);
222                    }
223                    break;
224                }
225            }
226            if self.messenger.socket.is_blocking_mode() {
227                break;
228            }
229        }
230
231        while let Ok(event) = self.user_event_receiver.try_recv() {
232            let addr = event.address();
233
234            // Check if session exists and if we can accept a new duplicate
235            let is_new_session = !self.sessions.contains_key(&addr);
236            let can_create = !is_new_session || self.can_accept_duplicate(&addr);
237
238            // Skip if we can't create a new session due to duplicate limit
239            if is_new_session && !can_create {
240                continue;
241            }
242
243            // Use entry API and process in one scope
244            use std::collections::hash_map::Entry;
245            match self.sessions.entry(addr) {
246                Entry::Occupied(mut entry) => {
247                    let session = entry.get_mut();
248                    let was_est = session.is_established();
249                    let actions = session.process_event(event, time);
250                    self.messenger.handle_actions(&addr, actions);
251                    if !was_est && session.is_established() {
252                        unestablished_sessions -= 1;
253                    }
254                }
255                Entry::Vacant(entry) => {
256                    let mut session = TSession::create_session(&self.messenger.config, addr, time);
257                    let actions = session.process_event(event, time);
258                    entry.insert(session);
259                    self.messenger.handle_actions(&addr, actions);
260                    self.increment_duplicate_count(&addr);
261                }
262            }
263        }
264
265        for (addr, session) in self.sessions.iter_mut() {
266            let actions = session.update(time);
267            self.messenger.handle_actions(addr, actions);
268        }
269
270        // Collect addresses to drop
271        let mut to_drop = Vec::new();
272        for (addr, session) in self.sessions.iter_mut() {
273            let (drop, actions) = session.should_drop(time);
274            self.messenger.handle_actions(addr, actions);
275            if drop {
276                to_drop.push(*addr);
277            }
278        }
279
280        // Remove dropped sessions and decrement duplicate counts
281        for addr in to_drop {
282            self.sessions.remove(&addr);
283            self.decrement_duplicate_count(&addr);
284        }
285
286        self.messenger.flush();
287    }
288
289    /// Returns the event sender for sending user events to sessions.
290    pub fn event_sender(&self) -> &Sender<TSession::SendEvent> {
291        &self.user_event_sender
292    }
293
294    /// Returns the event receiver for receiving session events.
295    pub fn event_receiver(&self) -> &Receiver<TSession::ReceiveEvent> {
296        &self.event_receiver
297    }
298
299    /// Returns a reference to the underlying socket.
300    pub fn socket(&self) -> &TSocket {
301        &self.messenger.socket
302    }
303
304    fn unestablished_session_count(&self) -> usize {
305        self.sessions.iter().filter(|s| !s.1.is_established()).count()
306    }
307
308    #[allow(dead_code)]
309    /// Returns a mutable reference to the underlying socket.
310    pub fn socket_mut(&mut self) -> &mut TSocket {
311        &mut self.messenger.socket
312    }
313
314    /// Returns the number of active sessions.
315    pub fn sessions_count(&self) -> usize {
316        self.sessions.len()
317    }
318
319    /// Returns a mutable reference to a specific session by address.
320    pub fn session_mut(&mut self, addr: &SocketAddr) -> Option<&mut TSession> {
321        self.sessions.get_mut(addr)
322    }
323
324    /// Returns an iterator over all established session addresses.
325    pub fn established_sessions(&self) -> impl Iterator<Item = &SocketAddr> {
326        self.sessions.iter().filter(|(_, s)| s.is_established()).map(|(addr, _)| addr)
327    }
328
329    /// Returns the number of established sessions.
330    pub fn established_sessions_count(&self) -> usize {
331        self.sessions.iter().filter(|(_, s)| s.is_established()).count()
332    }
333
334    /// Increments the duplicate peer count for the given address's IP.
335    fn increment_duplicate_count(&mut self, addr: &SocketAddr) {
336        let ip = addr.ip();
337        *self.duplicate_peer_count.entry(ip).or_insert(0) += 1;
338    }
339
340    /// Decrements the duplicate peer count for the given address's IP.
341    /// Removes the entry if count reaches zero.
342    fn decrement_duplicate_count(&mut self, addr: &SocketAddr) {
343        let ip = addr.ip();
344        if let Some(count) = self.duplicate_peer_count.get_mut(&ip) {
345            *count -= 1;
346            if *count == 0 {
347                self.duplicate_peer_count.remove(&ip);
348            }
349        }
350    }
351
352    /// Checks if adding a connection from this address would exceed the duplicate peer limit.
353    /// Returns true if the connection is allowed, false if it would exceed the limit.
354    fn can_accept_duplicate(&self, addr: &SocketAddr) -> bool {
355        // 0 means unlimited duplicates
356        if self.max_duplicate_peers == 0 {
357            return true;
358        }
359
360        let ip = addr.ip();
361        let current_count = self.duplicate_peer_count.get(&ip).copied().unwrap_or(0);
362        current_count < self.max_duplicate_peers as usize
363    }
364
365    /// Returns the number of connections from a specific IP address.
366    pub fn duplicate_peer_count(&self, addr: &SocketAddr) -> usize {
367        self.duplicate_peer_count.get(&addr.ip()).copied().unwrap_or(0)
368    }
369}