1use std::{net::SocketAddr, time::Instant};
2
3use bitfold_core::error::ErrorKind;
4use bitfold_peer::{Peer, PeerState};
5use bitfold_protocol::packet::{DeliveryGuarantee, OrderingGuarantee, Packet};
6use tracing::error;
7
8use super::{
9 event_types::{Action, SocketEvent},
10 session::{Session, SessionEventAddress},
11};
12
13impl SessionEventAddress for SocketEvent {
15 fn address(&self) -> SocketAddr {
17 match self {
18 SocketEvent::Packet(packet) => packet.addr(),
19 SocketEvent::Connect(addr) => *addr,
20 SocketEvent::Timeout(addr) => *addr,
21 SocketEvent::Disconnect(addr) => *addr,
22 }
23 }
24}
25
26impl SessionEventAddress for Packet {
28 fn address(&self) -> SocketAddr {
30 self.addr()
31 }
32}
33
34impl Session for Peer {
35 type SendEvent = Packet;
36 type ReceiveEvent = SocketEvent;
37
38 fn create_session(
39 config: &bitfold_core::config::Config,
40 address: SocketAddr,
41 time: Instant,
42 ) -> Peer {
43 Peer::new(address, config, time)
44 }
45
46 fn is_established(&self) -> bool {
47 self.is_established()
48 }
49
50 fn should_drop(&mut self, time: Instant) -> (bool, Vec<Action<Self::ReceiveEvent>>) {
51 let mut actions = Vec::new();
52
53 if self.state() == PeerState::Zombie {
55 actions.push(Action::Emit(SocketEvent::Disconnect(self.remote_address)));
56 return (true, actions);
57 }
58
59 let should_drop = self.packets_in_flight() > self.config().max_packets_in_flight
61 || self.last_heard(time) >= self.config().idle_connection_timeout;
62
63 if should_drop {
64 actions.push(Action::Emit(SocketEvent::Timeout(self.remote_address)));
65 if self.is_established() {
66 actions.push(Action::Emit(SocketEvent::Disconnect(self.remote_address)));
67 }
68 }
69 (should_drop, actions)
70 }
71
72 fn process_packet(&mut self, payload: &[u8], time: Instant) -> Vec<Action<Self::ReceiveEvent>> {
73 let mut actions = Vec::new();
74 if !payload.is_empty() {
75 self.update_bandwidth_window(time);
77 if !self.can_receive_within_bandwidth() {
78 tracing::warn!(
80 "Dropping packet ({} bytes) from {} due to incoming bandwidth limit (utilization {:.2})",
81 payload.len(),
82 self.remote_address,
83 self.incoming_bandwidth_utilization()
84 );
85 return actions; }
87
88 self.record_bytes_received(payload.len() as u32);
90
91 match self.process_command_packet(payload, time) {
93 Ok(packets) => {
94 if self.record_recv() {
95 actions.push(Action::Emit(SocketEvent::Connect(self.remote_address)));
96 }
97 for incoming in packets {
98 actions.push(Action::Emit(SocketEvent::Packet(incoming.0)));
99 }
100 }
101 Err(err) => error!("Error occurred processing command packet: {:?}", err),
102 }
103 } else {
104 error!("Error processing packet: {}", ErrorKind::ReceivedDataToShort);
105 }
106 actions
107 }
108
109 fn process_event(
110 &mut self,
111 event: Self::SendEvent,
112 _time: Instant,
113 ) -> Vec<Action<Self::ReceiveEvent>> {
114 let mut actions = Vec::new();
115 let addr = self.remote_address;
116 if self.record_send() {
117 actions.push(Action::Emit(SocketEvent::Connect(addr)));
118 }
119
120 let channel_id = event.channel_id();
122 let ordering = event.order_guarantee();
123
124 match event.delivery_guarantee() {
125 DeliveryGuarantee::Reliable => {
126 let ordered = !matches!(ordering, OrderingGuarantee::None);
129 self.enqueue_reliable_data(channel_id, event.payload_arc(), ordered);
130 }
131 DeliveryGuarantee::Unreliable => {
132 use bitfold_protocol::packet::OrderingGuarantee;
133
134 match ordering {
135 OrderingGuarantee::Unsequenced => {
136 let datagram_cap = std::cmp::min(
139 self.current_fragment_size() as usize,
140 self.config().receive_buffer_max_size,
141 );
142 let compression_overhead = match self.config().compression {
143 bitfold_core::config::CompressionAlgorithm::Lz4 => 5,
144 _ => 1,
145 };
146 let checksum_overhead = if self.config().use_checksums { 4 } else { 0 };
147 let per_packet_overhead =
148 1 + compression_overhead + checksum_overhead;
149 let send_unsequenced_header =
150 1 + 1 + 2 + 2 ; let max_payload_unseq = std::cmp::max(
152 1,
153 datagram_cap
154 .saturating_sub(per_packet_overhead)
155 .saturating_sub(2 )
156 .saturating_sub(send_unsequenced_header),
157 );
158
159 let base = bitfold_core::shared::SharedBytes::from_arc(event.payload_arc());
160 let mut offset = 0usize;
161 while offset < base.len() {
162 let len = std::cmp::min(max_payload_unseq, base.len() - offset);
163 let chunk = base.slice(offset, len);
164 let unsequenced_group = self.next_unsequenced_group();
165 self.enqueue_command(
166 bitfold_protocol::command::ProtocolCommand::SendUnsequenced {
167 channel_id,
168 unsequenced_group,
169 data: chunk,
170 },
171 );
172 offset += len;
173 }
174 }
175 _ => {
176 self.enqueue_unreliable_data(channel_id, event.payload_arc());
178 }
179 }
180 }
181 }
182
183 while self.has_queued_commands() && self.can_send_within_bandwidth() {
185 let cap = std::cmp::min(
186 self.current_fragment_size() as usize,
187 self.config().receive_buffer_max_size,
188 );
189 match self.encode_queued_commands_bounded(cap) {
190 Ok(Some(bytes)) => {
191 self.record_bytes_sent(bytes.len() as u32);
193 actions.push(Action::Send(bytes));
194 }
195 Ok(None) => break,
196 Err(e) => {
197 error!("Error encoding queued commands: {:?}", e);
198 break;
199 }
200 }
201 }
202 actions
205 }
206
207 fn update(&mut self, time: Instant) -> Vec<Action<Self::ReceiveEvent>> {
208 let mut actions = Vec::new();
209
210 self.update_bandwidth_window(time);
212
213 if self.is_established() {
215 if let Some(heartbeat_interval) = self.config().heartbeat_interval {
216 if self.last_sent(time) >= heartbeat_interval
218 && self.last_heard(time) >= heartbeat_interval
219 {
220 self.enqueue_ping_command(time.elapsed().as_millis() as u32);
222 }
223 }
224 }
225
226 while self.has_queued_commands() && self.can_send_within_bandwidth() {
229 let cap = std::cmp::min(
230 self.current_fragment_size() as usize,
231 self.config().receive_buffer_max_size,
232 );
233 match self.encode_queued_commands_bounded(cap) {
234 Ok(Some(bytes)) => {
235 self.record_bytes_sent(bytes.len() as u32);
236 actions.push(Action::Send(bytes));
237 }
238 Ok(None) => break,
239 Err(e) => {
240 error!("Error encoding queued commands: {:?}", e);
241 break;
242 }
243 }
244 }
245
246 self.handle_pmtu(time);
248
249 actions
250 }
251}
252
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use bitfold_protocol::{command::ProtocolCommand, command_codec::CommandDecoder};

    use super::*;
    use crate::session::Session;

    /// True when `bytes` decompress and decode into a packet carrying at least one `Ping`
    /// command. Any decoding failure counts as "not a ping".
    fn is_ping(bytes: &[u8]) -> bool {
        let Ok(decompressed) = CommandDecoder::decompress(bytes) else {
            return false;
        };
        let Ok(packet) = CommandDecoder::decode_packet(&decompressed) else {
            return false;
        };
        packet
            .commands
            .iter()
            .any(|cmd| matches!(cmd, ProtocolCommand::Ping { .. }))
    }

    /// True when any action sends a datagram that contains a ping.
    fn sends_ping(actions: &[Action<SocketEvent>]) -> bool {
        actions
            .iter()
            .any(|action| matches!(action, Action::Send(bytes) if is_ping(bytes)))
    }

    /// True when any action emits an application packet to the user.
    fn emits_packet(actions: &[Action<SocketEvent>]) -> bool {
        actions
            .iter()
            .any(|action| matches!(action, Action::Emit(SocketEvent::Packet(_))))
    }

    #[test]
    fn heartbeat_not_sent_when_recent_inbound() {
        let mut config = bitfold_core::config::Config::default();
        config.heartbeat_interval = Some(Duration::from_millis(50));
        let now = Instant::now();

        let mut peer = Peer::new("127.0.0.1:0".parse().unwrap(), &config, now);
        // Record traffic in both directions (presumably enough to mark the peer established).
        peer.record_send();
        peer.record_recv();

        // Outbound is idle past the interval, but inbound traffic is recent.
        peer.last_sent = now - Duration::from_millis(55);
        peer.last_heard = now - Duration::from_millis(10);

        let actions = peer.update(now);
        assert!(!sends_ping(&actions));
    }

    #[test]
    fn heartbeat_sent_when_bi_idle() {
        let mut config = bitfold_core::config::Config::default();
        config.heartbeat_interval = Some(Duration::from_millis(50));
        config.use_checksums = false;
        config.use_connection_handshake = false;
        let now = Instant::now();

        let mut peer = Peer::new("127.0.0.1:0".parse().unwrap(), &config, now);
        peer.record_send();
        peer.record_recv();

        // Both directions have been quiet for longer than the heartbeat interval.
        peer.last_sent = now - Duration::from_millis(60);
        peer.last_heard = now - Duration::from_millis(60);

        let actions = peer.update(now);
        assert!(sends_ping(&actions));
    }

    #[test]
    fn incoming_bandwidth_limit_drops_excess_packets() {
        let now = Instant::now();
        let addr = "127.0.0.1:0".parse().unwrap();

        // Produce one encoded datagram from a default-configured sender.
        let sender_config = bitfold_core::config::Config::default();
        let mut sender = Peer::new(addr, &sender_config, now);
        sender.enqueue_unreliable_data(0, vec![1, 2, 3, 4, 5, 6, 7, 8].into());
        let datagram = sender.encode_queued_commands().unwrap();

        // The receiver's budget admits exactly one copy of that datagram.
        let mut receiver_config = bitfold_core::config::Config::default();
        receiver_config.incoming_bandwidth_limit = datagram.len() as u32;
        let mut receiver = Peer::new(addr, &receiver_config, now);

        let first = <Peer as Session>::process_packet(&mut receiver, &datagram, now);
        assert!(emits_packet(&first));

        let second = <Peer as Session>::process_packet(&mut receiver, &datagram, now);
        assert!(!emits_packet(&second));
    }
}