bitfold_host/
peer_session.rs

1use std::{net::SocketAddr, time::Instant};
2
3use bitfold_core::error::ErrorKind;
4use bitfold_peer::{Peer, PeerState};
5use bitfold_protocol::packet::{DeliveryGuarantee, OrderingGuarantee, Packet};
6use tracing::error;
7
8use super::{
9    event_types::{Action, SocketEvent},
10    session::{Session, SessionEventAddress},
11};
12
13/// Required by `SessionManager` to properly handle session event.
14impl SessionEventAddress for SocketEvent {
15    /// Returns event address.
16    fn address(&self) -> SocketAddr {
17        match self {
18            SocketEvent::Packet(packet) => packet.addr(),
19            SocketEvent::Connect(addr) => *addr,
20            SocketEvent::Timeout(addr) => *addr,
21            SocketEvent::Disconnect(addr) => *addr,
22        }
23    }
24}
25
26/// Required by `SessionManager` to properly handle user event.
27impl SessionEventAddress for Packet {
28    /// Returns event address.
29    fn address(&self) -> SocketAddr {
30        self.addr()
31    }
32}
33
34impl Session for Peer {
35    type SendEvent = Packet;
36    type ReceiveEvent = SocketEvent;
37
38    fn create_session(
39        config: &bitfold_core::config::Config,
40        address: SocketAddr,
41        time: Instant,
42    ) -> Peer {
43        Peer::new(address, config, time)
44    }
45
46    fn is_established(&self) -> bool {
47        self.is_established()
48    }
49
50    fn should_drop(&mut self, time: Instant) -> (bool, Vec<Action<Self::ReceiveEvent>>) {
51        let mut actions = Vec::new();
52
53        // Check if peer received disconnect command (zombie state)
54        if self.state() == PeerState::Zombie {
55            actions.push(Action::Emit(SocketEvent::Disconnect(self.remote_address)));
56            return (true, actions);
57        }
58
59        // Check for timeout or too many packets in flight
60        let should_drop = self.packets_in_flight() > self.config().max_packets_in_flight
61            || self.last_heard(time) >= self.config().idle_connection_timeout;
62
63        if should_drop {
64            actions.push(Action::Emit(SocketEvent::Timeout(self.remote_address)));
65            if self.is_established() {
66                actions.push(Action::Emit(SocketEvent::Disconnect(self.remote_address)));
67            }
68        }
69        (should_drop, actions)
70    }
71
72    fn process_packet(&mut self, payload: &[u8], time: Instant) -> Vec<Action<Self::ReceiveEvent>> {
73        let mut actions = Vec::new();
74        if !payload.is_empty() {
75            // Update inbound bandwidth window and enforce incoming bandwidth limit
76            self.update_bandwidth_window(time);
77            if !self.can_receive_within_bandwidth() {
78                // Over incoming bandwidth limit: drop packet for this window
79                tracing::warn!(
80                    "Dropping packet ({} bytes) from {} due to incoming bandwidth limit (utilization {:.2})",
81                    payload.len(),
82                    self.remote_address,
83                    self.incoming_bandwidth_utilization()
84                );
85                return actions; // No actions emitted
86            }
87
88            // Track bytes received for bandwidth monitoring (only after passing the limit check)
89            self.record_bytes_received(payload.len() as u32);
90
91            // Process command packet
92            match self.process_command_packet(payload, time) {
93                Ok(packets) => {
94                    if self.record_recv() {
95                        actions.push(Action::Emit(SocketEvent::Connect(self.remote_address)));
96                    }
97                    for incoming in packets {
98                        actions.push(Action::Emit(SocketEvent::Packet(incoming.0)));
99                    }
100                }
101                Err(err) => error!("Error occurred processing command packet: {:?}", err),
102            }
103        } else {
104            error!("Error processing packet: {}", ErrorKind::ReceivedDataToShort);
105        }
106        actions
107    }
108
109    fn process_event(
110        &mut self,
111        event: Self::SendEvent,
112        _time: Instant,
113    ) -> Vec<Action<Self::ReceiveEvent>> {
114        let mut actions = Vec::new();
115        let addr = self.remote_address;
116        if self.record_send() {
117            actions.push(Action::Emit(SocketEvent::Connect(addr)));
118        }
119
120        // Convert user packet to command
121        let channel_id = event.channel_id();
122        let ordering = event.order_guarantee();
123
124        match event.delivery_guarantee() {
125            DeliveryGuarantee::Reliable => {
126                // Use enqueue_reliable_data which handles fragmentation
127                // Reliable unordered when ordering is None, otherwise ordered
128                let ordered = !matches!(ordering, OrderingGuarantee::None);
129                self.enqueue_reliable_data(channel_id, event.payload_arc(), ordered);
130            }
131            DeliveryGuarantee::Unreliable => {
132                use bitfold_protocol::packet::OrderingGuarantee;
133
134                match ordering {
135                    OrderingGuarantee::Unsequenced => {
136                        // Unsequenced: prevents duplicates without ordering.
137                        // Chunk into multiple unsequenced commands if needed to fit MTU budget.
138                        let datagram_cap = std::cmp::min(
139                            self.current_fragment_size() as usize,
140                            self.config().receive_buffer_max_size,
141                        );
142                        let compression_overhead = match self.config().compression {
143                            bitfold_core::config::CompressionAlgorithm::Lz4 => 5,
144                            _ => 1,
145                        };
146                        let checksum_overhead = if self.config().use_checksums { 4 } else { 0 };
147                        let per_packet_overhead =
148                            1 /* command count */ + compression_overhead + checksum_overhead;
149                        let send_unsequenced_header =
150                            1 /* type */ + 1 /* channel */ + 2 /* unseq group */ + 2 /* len */; // = 6
151                        let max_payload_unseq = std::cmp::max(
152                            1,
153                            datagram_cap
154                                .saturating_sub(per_packet_overhead)
155                                .saturating_sub(2 /* len prefix */)
156                                .saturating_sub(send_unsequenced_header),
157                        );
158
159                        let base = bitfold_core::shared::SharedBytes::from_arc(event.payload_arc());
160                        let mut offset = 0usize;
161                        while offset < base.len() {
162                            let len = std::cmp::min(max_payload_unseq, base.len() - offset);
163                            let chunk = base.slice(offset, len);
164                            let unsequenced_group = self.next_unsequenced_group();
165                            self.enqueue_command(
166                                bitfold_protocol::command::ProtocolCommand::SendUnsequenced {
167                                    channel_id,
168                                    unsequenced_group,
169                                    data: chunk,
170                                },
171                            );
172                            offset += len;
173                        }
174                    }
175                    _ => {
176                        // Regular unreliable (no sequencing or ordering); allow fragmentation
177                        self.enqueue_unreliable_data(channel_id, event.payload_arc());
178                    }
179                }
180            }
181        }
182
183        // Flush commands immediately if within bandwidth, splitting into MTU-sized datagrams
184        while self.has_queued_commands() && self.can_send_within_bandwidth() {
185            let cap = std::cmp::min(
186                self.current_fragment_size() as usize,
187                self.config().receive_buffer_max_size,
188            );
189            match self.encode_queued_commands_bounded(cap) {
190                Ok(Some(bytes)) => {
191                    // Track bytes for bandwidth throttling
192                    self.record_bytes_sent(bytes.len() as u32);
193                    actions.push(Action::Send(bytes));
194                }
195                Ok(None) => break,
196                Err(e) => {
197                    error!("Error encoding queued commands: {:?}", e);
198                    break;
199                }
200            }
201        }
202        // If over bandwidth limit, keep commands queued for next window
203
204        actions
205    }
206
207    fn update(&mut self, time: Instant) -> Vec<Action<Self::ReceiveEvent>> {
208        let mut actions = Vec::new();
209
210        // Update bandwidth tracking window
211        self.update_bandwidth_window(time);
212
213        // Enqueue ping for keepalive if needed
214        if self.is_established() {
215            if let Some(heartbeat_interval) = self.config().heartbeat_interval {
216                // Only send heartbeat when both directions have been idle long enough.
217                if self.last_sent(time) >= heartbeat_interval
218                    && self.last_heard(time) >= heartbeat_interval
219                {
220                    // Use command-based Ping for keepalive
221                    self.enqueue_ping_command(time.elapsed().as_millis() as u32);
222                }
223            }
224        }
225
226        // Flush any queued commands (ACKs, Pongs, Pings, etc.) if within bandwidth,
227        // splitting into MTU-sized datagrams
228        while self.has_queued_commands() && self.can_send_within_bandwidth() {
229            let cap = std::cmp::min(
230                self.current_fragment_size() as usize,
231                self.config().receive_buffer_max_size,
232            );
233            match self.encode_queued_commands_bounded(cap) {
234                Ok(Some(bytes)) => {
235                    self.record_bytes_sent(bytes.len() as u32);
236                    actions.push(Action::Send(bytes));
237                }
238                Ok(None) => break,
239                Err(e) => {
240                    error!("Error encoding queued commands: {:?}", e);
241                    break;
242                }
243            }
244        }
245
246        // Application-level PMTU discovery & per-peer fragment size tuning
247        self.handle_pmtu(time);
248
249        actions
250    }
251}
252
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use bitfold_protocol::{command::ProtocolCommand, command_codec::CommandDecoder};

    use super::*;
    use crate::session::Session;

    /// Heartbeat interval used by the keepalive tests.
    const HEARTBEAT: Duration = Duration::from_millis(50);

    /// Returns true if `bytes` decompresses and decodes to a packet containing a Ping command.
    fn is_ping(bytes: &[u8]) -> bool {
        // Decompress first (packets are compressed by default).
        if let Ok(decompressed) = CommandDecoder::decompress(bytes) {
            if let Ok(packet) = CommandDecoder::decode_packet(&decompressed) {
                return packet
                    .commands
                    .iter()
                    .any(|cmd| matches!(cmd, ProtocolCommand::Ping { .. }));
            }
        }
        false
    }

    /// True if any `Send` action in `actions` carries a Ping command.
    fn sends_ping(actions: &[Action<SocketEvent>]) -> bool {
        actions
            .iter()
            .any(|a| matches!(a, Action::Send(bytes) if is_ping(bytes)))
    }

    /// Config shared by both heartbeat tests.
    ///
    /// Checksums and the connection handshake are disabled, as the positive test always did,
    /// so that both tests exercise the same path. Before this, the negative test used the
    /// defaults and could pass vacuously: the ping might not decode in `is_ping`, or the peer
    /// might never count as established.
    fn heartbeat_cfg() -> bitfold_core::config::Config {
        let mut cfg = bitfold_core::config::Config::default();
        cfg.heartbeat_interval = Some(HEARTBEAT);
        cfg.use_checksums = false;
        cfg.use_connection_handshake = false;
        cfg
    }

    /// Creates a peer at `start` and marks it established with one send and one receive.
    fn established_peer(cfg: &bitfold_core::config::Config, start: Instant) -> Peer {
        let mut conn = Peer::new("127.0.0.1:0".parse().unwrap(), cfg, start);
        conn.record_send();
        conn.record_recv();
        conn
    }

    #[test]
    fn heartbeat_not_sent_when_recent_inbound() {
        let cfg = heartbeat_cfg();
        let start = Instant::now();
        let mut conn = established_peer(&cfg, start);

        // Haven't sent for >= interval, but have received recently.
        conn.last_sent = start - Duration::from_millis(55);
        conn.last_heard = start - Duration::from_millis(10);
        assert!(conn.is_established(), "precondition: peer must be established");

        let actions = conn.update(start);
        assert!(!sends_ping(&actions));
    }

    #[test]
    fn heartbeat_sent_when_bi_idle() {
        let cfg = heartbeat_cfg();
        let start = Instant::now();
        let mut conn = established_peer(&cfg, start);

        // Both directions idle past the interval.
        conn.last_sent = start - Duration::from_millis(60);
        conn.last_heard = start - Duration::from_millis(60);

        let actions = conn.update(start);
        assert!(sends_ping(&actions));
    }

    #[test]
    fn incoming_bandwidth_limit_drops_excess_packets() {
        // Build a small encoded packet from a client peer.
        let start = Instant::now();
        let client_cfg = bitfold_core::config::Config::default();
        let addr = "127.0.0.1:0".parse().unwrap();
        let mut client = Peer::new(addr, &client_cfg, start);

        client.enqueue_unreliable_data(0, vec![1, 2, 3, 4, 5, 6, 7, 8].into());
        let encoded = client.encode_queued_commands().unwrap();

        // Server peer whose incoming limit admits exactly one such packet per window.
        let mut server_cfg = bitfold_core::config::Config::default();
        server_cfg.incoming_bandwidth_limit = encoded.len() as u32;
        let mut server = Peer::new(addr, &server_cfg, start);

        let is_packet_event = |a: &Action<SocketEvent>| matches!(a, Action::Emit(SocketEvent::Packet(_)));

        // First packet is processed (at least one Packet event emitted).
        let actions1 = <Peer as Session>::process_packet(&mut server, &encoded, start);
        assert!(actions1.iter().any(is_packet_event));

        // Second packet within the same window is dropped due to the limit.
        let actions2 = <Peer as Session>::process_packet(&mut server, &encoded, start);
        assert!(!actions2.iter().any(is_packet_event));
    }
}
349}