1use crate::relay::{RelayError, RelayResult};
4use std::collections::VecDeque;
5use std::net::SocketAddr;
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant};
8use tokio::sync::mpsc;
9
/// Limits and timers applied to a single [`RelayConnection`].
#[derive(Debug, Clone)]
pub struct RelayConnectionConfig {
    /// Largest payload, in bytes, that `RelayConnection::send_data` accepts in one call.
    pub max_frame_size: usize,
    /// Upper bound, in bytes, on the data held in the outgoing and incoming queues combined.
    pub buffer_size: usize,
    /// Idle period after which `RelayConnection::check_timeout` reports the session as expired.
    pub connection_timeout: Duration,
    /// Delay between one keep-alive and the moment the next one becomes due.
    pub keep_alive_interval: Duration,
    /// Byte budget handed to the connection's bandwidth tracker for each accounting window.
    pub bandwidth_limit: u64,
}

impl Default for RelayConnectionConfig {
    /// Returns the stock configuration: 64 KiB frames, a 1 MiB buffer, a five-minute idle
    /// timeout, a thirty-second keep-alive interval and a 1 MiB bandwidth budget.
    fn default() -> Self {
        let one_mib = 1024 * 1024;

        RelayConnectionConfig {
            max_frame_size: 64 * 1024,
            buffer_size: one_mib,
            connection_timeout: Duration::from_secs(5 * 60),
            keep_alive_interval: Duration::from_secs(30),
            bandwidth_limit: 1024 * 1024,
        }
    }
}
36
37#[derive(Debug, Clone)]
39pub enum RelayEvent {
40 ConnectionEstablished {
42 session_id: u32,
43 peer_addr: SocketAddr,
44 },
45 DataReceived { session_id: u32, data: Vec<u8> },
47 ConnectionTerminated { session_id: u32, reason: String },
49 Error {
51 session_id: Option<u32>,
52 error: RelayError,
53 },
54 BandwidthLimitExceeded {
56 session_id: u32,
57 current_usage: u64,
58 limit: u64,
59 },
60 KeepAlive { session_id: u32 },
62}
63
/// Commands that can be delivered to a relay connection over its action channel.
// NOTE(review): nothing in this file consumes these yet; the per-variant notes
// below describe the apparent intent only — confirm against the consumer.
#[derive(Debug, Clone)]
pub enum RelayAction {
    /// Request that `data` be sent on the session.
    SendData {
        session_id: u32,
        data: Vec<u8>,
    },
    /// Request that the session be shut down with the given reason.
    TerminateConnection {
        session_id: u32,
        reason: String,
    },
    /// Request that the session's bandwidth limit be replaced by `new_limit`.
    UpdateBandwidthLimit {
        session_id: u32,
        new_limit: u64,
    },
    /// Request that a keep-alive be issued for the session.
    SendKeepAlive {
        session_id: u32,
    },
}
76
77#[derive(Debug)]
79pub struct RelayConnection {
80 session_id: u32,
82 peer_addr: SocketAddr,
84 config: RelayConnectionConfig,
86 state: Arc<Mutex<ConnectionState>>,
88 event_sender: mpsc::UnboundedSender<RelayEvent>,
90 action_receiver: mpsc::UnboundedReceiver<RelayAction>,
92}
93
94#[derive(Debug)]
96struct ConnectionState {
97 is_active: bool,
99 outgoing_queue: VecDeque<Vec<u8>>,
101 incoming_queue: VecDeque<Vec<u8>>,
103 buffer_usage: usize,
105 bandwidth_tracker: BandwidthTracker,
107 last_activity: Instant,
109 next_keep_alive: Instant,
111}
112
/// Fixed-window byte counter used to enforce a per-window send limit.
///
/// Counters are zeroed lazily: every accessor first checks whether the current
/// window has elapsed and, if so, starts a new one.
#[derive(Debug)]
struct BandwidthTracker {
    /// Bytes recorded as sent in the current window.
    bytes_sent: u64,
    /// Bytes recorded as received in the current window.
    bytes_received: u64,
    /// Start of the current accounting window.
    window_start: Instant,
    /// Length of one accounting window.
    window_duration: Duration,
    /// Maximum number of bytes that may be sent within one window.
    limit: u64,
}

impl BandwidthTracker {
    /// Creates a tracker with empty counters, a one-second window starting now,
    /// and the given per-window send `limit`.
    fn new(limit: u64) -> Self {
        Self {
            bytes_sent: 0,
            bytes_received: 0,
            window_start: Instant::now(),
            window_duration: Duration::from_secs(1),
            limit,
        }
    }

    /// Starts a fresh window (zeroing both counters) if the current one has elapsed.
    fn reset_if_needed(&mut self) {
        let now = Instant::now();
        if now.duration_since(self.window_start) >= self.window_duration {
            self.bytes_sent = 0;
            self.bytes_received = 0;
            self.window_start = now;
        }
    }

    /// Returns `true` when sending `bytes` more bytes stays within the limit
    /// for the current window. Only sent bytes count against the limit.
    fn can_send(&mut self, bytes: u64) -> bool {
        self.reset_if_needed();
        // A sum that does not fit in u64 is necessarily over the limit; checked
        // arithmetic avoids a debug-build panic and a release-build wrap-around
        // that would wrongly admit the send.
        self.bytes_sent
            .checked_add(bytes)
            .map_or(false, |total| total <= self.limit)
    }

    /// Adds `bytes` to the sent counter of the current window (saturating).
    fn record_sent(&mut self, bytes: u64) {
        self.reset_if_needed();
        self.bytes_sent = self.bytes_sent.saturating_add(bytes);
    }

    /// Adds `bytes` to the received counter of the current window (saturating).
    fn record_received(&mut self, bytes: u64) {
        self.reset_if_needed();
        self.bytes_received = self.bytes_received.saturating_add(bytes);
    }

    /// Returns sent plus received bytes for the current window, saturating at `u64::MAX`.
    fn current_usage(&mut self) -> u64 {
        self.reset_if_needed();
        self.bytes_sent.saturating_add(self.bytes_received)
    }
}
168
169impl RelayConnection {
170 pub fn new(
172 session_id: u32,
173 peer_addr: SocketAddr,
174 config: RelayConnectionConfig,
175 event_sender: mpsc::UnboundedSender<RelayEvent>,
176 action_receiver: mpsc::UnboundedReceiver<RelayAction>,
177 ) -> Self {
178 let now = Instant::now();
179 let state = ConnectionState {
180 is_active: true,
181 outgoing_queue: VecDeque::new(),
182 incoming_queue: VecDeque::new(),
183 buffer_usage: 0,
184 bandwidth_tracker: BandwidthTracker::new(config.bandwidth_limit),
185 last_activity: now,
186 next_keep_alive: now + config.keep_alive_interval,
187 };
188
189 Self {
190 session_id,
191 peer_addr,
192 config,
193 state: Arc::new(Mutex::new(state)),
194 event_sender,
195 action_receiver,
196 }
197 }
198
199 pub fn session_id(&self) -> u32 {
201 self.session_id
202 }
203
204 pub fn peer_addr(&self) -> SocketAddr {
206 self.peer_addr
207 }
208
209 pub fn is_active(&self) -> bool {
211 let state = self.state.lock().unwrap();
212 state.is_active
213 }
214
215 pub fn send_data(&self, data: Vec<u8>) -> RelayResult<()> {
217 if data.len() > self.config.max_frame_size {
218 return Err(RelayError::ProtocolError {
219 frame_type: 0x46, reason: format!(
221 "Data size {} exceeds maximum {}",
222 data.len(),
223 self.config.max_frame_size
224 ),
225 });
226 }
227
228 let mut state = self.state.lock().unwrap();
229
230 if !state.is_active {
231 return Err(RelayError::SessionError {
232 session_id: Some(self.session_id),
233 kind: crate::relay::error::SessionErrorKind::Terminated,
234 });
235 }
236
237 if !state.bandwidth_tracker.can_send(data.len() as u64) {
239 let current_usage = state.bandwidth_tracker.current_usage();
240 return Err(RelayError::SessionError {
241 session_id: Some(self.session_id),
242 kind: crate::relay::error::SessionErrorKind::BandwidthExceeded {
243 used: current_usage,
244 limit: self.config.bandwidth_limit,
245 },
246 });
247 }
248
249 if state.buffer_usage + data.len() > self.config.buffer_size {
251 return Err(RelayError::ResourceExhausted {
252 resource_type: "buffer".to_string(),
253 current_usage: state.buffer_usage as u64,
254 limit: self.config.buffer_size as u64,
255 });
256 }
257
258 state.bandwidth_tracker.record_sent(data.len() as u64);
260 state.buffer_usage += data.len();
261 state.outgoing_queue.push_back(data.clone());
262 state.last_activity = Instant::now();
263
264 let _ = self.event_sender.send(RelayEvent::DataReceived {
266 session_id: self.session_id,
267 data,
268 });
269
270 Ok(())
271 }
272
273 pub fn receive_data(&self, data: Vec<u8>) -> RelayResult<()> {
275 let mut state = self.state.lock().unwrap();
276
277 if !state.is_active {
278 return Err(RelayError::SessionError {
279 session_id: Some(self.session_id),
280 kind: crate::relay::error::SessionErrorKind::Terminated,
281 });
282 }
283
284 if state.buffer_usage + data.len() > self.config.buffer_size {
286 return Err(RelayError::ResourceExhausted {
287 resource_type: "buffer".to_string(),
288 current_usage: state.buffer_usage as u64,
289 limit: self.config.buffer_size as u64,
290 });
291 }
292
293 state.bandwidth_tracker.record_received(data.len() as u64);
295 state.buffer_usage += data.len();
296 state.incoming_queue.push_back(data.clone());
297 state.last_activity = Instant::now();
298
299 let _ = self.event_sender.send(RelayEvent::DataReceived {
301 session_id: self.session_id,
302 data,
303 });
304
305 Ok(())
306 }
307
308 pub fn next_outgoing(&self) -> Option<Vec<u8>> {
310 let mut state = self.state.lock().unwrap();
311 if let Some(data) = state.outgoing_queue.pop_front() {
312 state.buffer_usage = state.buffer_usage.saturating_sub(data.len());
313 Some(data)
314 } else {
315 None
316 }
317 }
318
319 pub fn next_incoming(&self) -> Option<Vec<u8>> {
321 let mut state = self.state.lock().unwrap();
322 if let Some(data) = state.incoming_queue.pop_front() {
323 state.buffer_usage = state.buffer_usage.saturating_sub(data.len());
324 Some(data)
325 } else {
326 None
327 }
328 }
329
330 pub fn check_timeout(&self) -> RelayResult<()> {
332 let state = self.state.lock().unwrap();
333 let now = Instant::now();
334
335 if now.duration_since(state.last_activity) > self.config.connection_timeout {
336 return Err(RelayError::SessionError {
337 session_id: Some(self.session_id),
338 kind: crate::relay::error::SessionErrorKind::Expired,
339 });
340 }
341
342 Ok(())
343 }
344
345 pub fn should_send_keep_alive(&self) -> bool {
347 let state = self.state.lock().unwrap();
348 Instant::now() >= state.next_keep_alive
349 }
350
351 pub fn send_keep_alive(&self) -> RelayResult<()> {
353 let mut state = self.state.lock().unwrap();
354 state.next_keep_alive = Instant::now() + self.config.keep_alive_interval;
355
356 let _ = self.event_sender.send(RelayEvent::KeepAlive {
357 session_id: self.session_id,
358 });
359
360 Ok(())
361 }
362
363 pub fn terminate(&self, reason: String) -> RelayResult<()> {
365 let mut state = self.state.lock().unwrap();
366 state.is_active = false;
367
368 let _ = self.event_sender.send(RelayEvent::ConnectionTerminated {
369 session_id: self.session_id,
370 reason: reason.clone(),
371 });
372
373 Ok(())
374 }
375
376 pub fn get_stats(&self) -> ConnectionStats {
378 let state = self.state.lock().unwrap();
379 ConnectionStats {
380 session_id: self.session_id,
381 peer_addr: self.peer_addr,
382 is_active: state.is_active,
383 bytes_sent: state.bandwidth_tracker.bytes_sent,
384 bytes_received: state.bandwidth_tracker.bytes_received,
385 buffer_usage: state.buffer_usage,
386 outgoing_queue_size: state.outgoing_queue.len(),
387 incoming_queue_size: state.incoming_queue.len(),
388 last_activity: state.last_activity,
389 }
390 }
391}
392
/// Point-in-time snapshot of a connection, as produced by `RelayConnection::get_stats`.
#[derive(Debug, Clone)]
pub struct ConnectionStats {
    /// Identifier of the session the snapshot describes.
    pub session_id: u32,
    /// Address of the remote peer.
    pub peer_addr: SocketAddr,
    /// Whether the connection was still active when the snapshot was taken.
    pub is_active: bool,
    /// Sent-byte counter of the bandwidth tracker at snapshot time.
    pub bytes_sent: u64,
    /// Received-byte counter of the bandwidth tracker at snapshot time.
    pub bytes_received: u64,
    /// Bytes held in the outgoing and incoming queues combined.
    pub buffer_usage: usize,
    /// Number of payloads waiting in the outgoing queue.
    pub outgoing_queue_size: usize,
    /// Number of payloads waiting in the incoming queue.
    pub incoming_queue_size: usize,
    /// Time of the most recent accepted send or receive.
    pub last_activity: Instant,
}
406
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::sync::mpsc;

    /// Session id used by every test connection.
    const SESSION_ID: u32 = 123;

    /// Loopback address used as the peer in every test.
    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    /// Builds a connection with the given config.
    ///
    /// The far ends of both channels are returned so the caller can keep them
    /// alive for the duration of the test.
    fn connect(
        config: RelayConnectionConfig,
    ) -> (
        RelayConnection,
        mpsc::UnboundedReceiver<RelayEvent>,
        mpsc::UnboundedSender<RelayAction>,
    ) {
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        let (action_tx, action_rx) = mpsc::unbounded_channel();
        let connection = RelayConnection::new(SESSION_ID, test_addr(), config, event_tx, action_rx);
        (connection, event_rx, action_tx)
    }

    #[test]
    fn test_relay_connection_creation() {
        let (connection, _event_rx, _action_tx) = connect(RelayConnectionConfig::default());

        assert_eq!(connection.session_id(), SESSION_ID);
        assert_eq!(connection.peer_addr(), test_addr());
        assert!(connection.is_active());
    }

    #[test]
    fn test_send_data_within_limits() {
        let (connection, _event_rx, _action_tx) = connect(RelayConnectionConfig::default());

        let data = vec![1, 2, 3, 4];
        assert!(connection.send_data(data.clone()).is_ok());

        assert_eq!(connection.next_outgoing(), Some(data));
        assert_eq!(connection.next_outgoing(), None);
    }

    #[test]
    fn test_send_data_exceeds_frame_size() {
        let config = RelayConnectionConfig {
            max_frame_size: 10,
            ..RelayConnectionConfig::default()
        };
        let (connection, _event_rx, _action_tx) = connect(config);

        let result = connection.send_data(vec![0; 20]);
        assert!(matches!(result, Err(RelayError::ProtocolError { .. })));
    }

    #[test]
    fn test_bandwidth_limiting() {
        let config = RelayConnectionConfig {
            bandwidth_limit: 100,
            ..RelayConnectionConfig::default()
        };
        let (connection, _event_rx, _action_tx) = connect(config);

        // 50 of 100 bytes used: accepted.
        assert!(connection.send_data(vec![0; 50]).is_ok());

        // 50 + 60 > 100 within the same window: rejected.
        let result = connection.send_data(vec![0; 60]);
        assert!(matches!(
            result,
            Err(RelayError::SessionError {
                kind: crate::relay::error::SessionErrorKind::BandwidthExceeded { .. },
                ..
            })
        ));
    }

    #[test]
    fn test_buffer_size_limiting() {
        let config = RelayConnectionConfig {
            buffer_size: 100,
            ..RelayConnectionConfig::default()
        };
        let (connection, _event_rx, _action_tx) = connect(config);

        // 80 of 100 bytes buffered: accepted.
        assert!(connection.send_data(vec![0; 80]).is_ok());

        // 80 + 30 > 100: rejected.
        let result = connection.send_data(vec![0; 30]);
        assert!(matches!(result, Err(RelayError::ResourceExhausted { .. })));
    }

    #[test]
    fn test_connection_termination() {
        let (connection, _event_rx, _action_tx) = connect(RelayConnectionConfig::default());

        assert!(connection.is_active());

        let reason = "Test termination".to_string();
        assert!(connection.terminate(reason).is_ok());

        assert!(!connection.is_active());

        // A terminated connection refuses further sends.
        let result = connection.send_data(vec![1, 2, 3]);
        assert!(matches!(
            result,
            Err(RelayError::SessionError {
                kind: crate::relay::error::SessionErrorKind::Terminated,
                ..
            })
        ));
    }

    #[test]
    fn test_keep_alive() {
        // The interval is wide enough that scheduler jitter between the calls
        // below cannot flip the "not yet due" assertions on a loaded machine.
        let interval = Duration::from_millis(50);
        let config = RelayConnectionConfig {
            keep_alive_interval: interval,
            ..RelayConnectionConfig::default()
        };
        let (connection, _event_rx, _action_tx) = connect(config);

        // Not due immediately after creation.
        assert!(!connection.should_send_keep_alive());

        std::thread::sleep(interval + Duration::from_millis(10));

        // Due once the interval has passed.
        assert!(connection.should_send_keep_alive());

        assert!(connection.send_keep_alive().is_ok());

        // Sending reschedules the next keep-alive.
        assert!(!connection.should_send_keep_alive());
    }

    #[test]
    fn test_connection_stats() {
        let (connection, _event_rx, _action_tx) = connect(RelayConnectionConfig::default());

        connection.send_data(vec![1, 2, 3]).unwrap();
        connection.receive_data(vec![4, 5, 6, 7]).unwrap();

        let stats = connection.get_stats();
        assert_eq!(stats.session_id, SESSION_ID);
        assert_eq!(stats.peer_addr, test_addr());
        assert!(stats.is_active);
        assert_eq!(stats.bytes_sent, 3);
        assert_eq!(stats.bytes_received, 4);
        assert_eq!(stats.buffer_usage, 7);
        assert_eq!(stats.outgoing_queue_size, 1);
        assert_eq!(stats.incoming_queue_size, 1);
    }
}