1use crate::relay::{RelayError, RelayResult};
4use std::collections::VecDeque;
5use std::net::SocketAddr;
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant};
8use tokio::sync::mpsc;
9
/// Tunable limits for a single relay connection.
#[derive(Debug, Clone)]
pub struct RelayConnectionConfig {
    /// Largest payload, in bytes, that one `send_data` call accepts.
    pub max_frame_size: usize,
    /// Upper bound, in bytes, on data held in the outgoing and incoming queues combined.
    pub buffer_size: usize,
    /// Maximum time without recorded send/receive activity before `check_timeout` fails.
    pub connection_timeout: Duration,
    /// Delay between keep-alives, used to schedule the next one.
    pub keep_alive_interval: Duration,
    /// Maximum bytes that may be sent within one bandwidth-tracking window.
    pub bandwidth_limit: u64,
}

impl Default for RelayConnectionConfig {
    /// 64 KiB frames, a 1 MiB queue buffer, a 5-minute idle timeout,
    /// keep-alives every 30 seconds and a 1 MiB per-window send limit.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;

        Self {
            max_frame_size: 64 * KIB,
            buffer_size: MIB,
            connection_timeout: Duration::from_secs(5 * 60),
            keep_alive_interval: Duration::from_secs(30),
            bandwidth_limit: MIB as u64,
        }
    }
}
36
/// Notifications published on a relay connection's event channel.
///
/// `RelayConnection` itself emits only `DataReceived` (from both `send_data` and
/// `receive_data`), `ConnectionTerminated` (from `terminate`) and `KeepAlive`
/// (from `send_keep_alive`); the other variants are not produced by it.
#[derive(Debug, Clone)]
pub enum RelayEvent {
    /// Carries the session id and peer address of a connection.
    /// Not emitted by `RelayConnection`.
    ConnectionEstablished {
        session_id: u32,
        peer_addr: SocketAddr,
    },
    /// Carries a payload for `session_id`.
    /// NOTE(review): `RelayConnection::send_data` also emits this variant for
    /// outbound data, so it does not strictly mean "received" — confirm intent.
    DataReceived {
        session_id: u32,
        data: Vec<u8>,
    },
    /// Emitted by `RelayConnection::terminate` with the caller-supplied reason.
    ConnectionTerminated {
        session_id: u32,
        reason: String,
    },
    /// Wraps a `RelayError`; `session_id` is `None` when no session applies.
    /// Not emitted by `RelayConnection`.
    Error {
        session_id: Option<u32>,
        error: RelayError,
    },
    /// Reports usage against a bandwidth limit for a session.
    /// Not emitted by `RelayConnection` (it returns an error from `send_data` instead).
    BandwidthLimitExceeded {
        session_id: u32,
        current_usage: u64,
        limit: u64,
    },
    /// Emitted by `RelayConnection::send_keep_alive`.
    KeepAlive {
        session_id: u32,
    },
}
71
/// Commands addressed to a relay session, carried on the channel whose receiving
/// end is stored in `RelayConnection::action_receiver`.
///
/// NOTE(review): nothing in this file reads that receiver, so these actions are
/// not handled here — confirm where they are consumed.
#[derive(Debug, Clone)]
pub enum RelayAction {
    /// Payload `data` targeted at session `session_id`.
    SendData {
        session_id: u32,
        data: Vec<u8>,
    },
    /// Termination request for `session_id`, with a human-readable reason.
    TerminateConnection {
        session_id: u32,
        reason: String,
    },
    /// New bandwidth limit for `session_id` (presumably bytes per window — confirm).
    UpdateBandwidthLimit {
        session_id: u32,
        new_limit: u64,
    },
    /// Keep-alive request for `session_id`.
    SendKeepAlive {
        session_id: u32,
    },
}
95
/// A single relayed session with one peer: queues, limits and timers live in a
/// mutex-guarded `ConnectionState`, and notifications go out on an event channel.
#[derive(Debug)]
pub struct RelayConnection {
    /// Identifier included in emitted events and session errors.
    session_id: u32,
    /// Address of the remote peer.
    peer_addr: SocketAddr,
    /// Limits applied by `send_data`/`receive_data` and the timeout/keep-alive helpers.
    config: RelayConnectionConfig,
    /// Mutable per-connection state. NOTE(review): the `Arc` is never cloned in this
    /// file, so a plain `Mutex` may suffice — confirm no external sharing is planned.
    state: Arc<Mutex<ConnectionState>>,
    /// Channel for outgoing `RelayEvent`s; send failures are deliberately ignored.
    event_sender: mpsc::UnboundedSender<RelayEvent>,
    /// NOTE(review): stored by `new` but never read in this file — confirm who drains it.
    action_receiver: mpsc::UnboundedReceiver<RelayAction>,
}
112
/// Mutable state of a `RelayConnection`, always accessed through its mutex.
#[derive(Debug)]
struct ConnectionState {
    /// Set to `false` by `terminate`; sends and receives are rejected afterwards.
    is_active: bool,
    /// Payloads accepted by `send_data`, drained FIFO by `next_outgoing`.
    outgoing_queue: VecDeque<Vec<u8>>,
    /// Payloads accepted by `receive_data`, drained FIFO by `next_incoming`.
    incoming_queue: VecDeque<Vec<u8>>,
    /// Bytes currently held across both queues; bounded by `config.buffer_size`.
    buffer_usage: usize,
    /// Per-window byte counters used to enforce `config.bandwidth_limit` on sends.
    bandwidth_tracker: BandwidthTracker,
    /// Time of the last accepted send or receive; compared with `connection_timeout`.
    last_activity: Instant,
    /// Instant at which the next keep-alive becomes due.
    next_keep_alive: Instant,
}
131
/// Fixed-window byte counter used to rate-limit outgoing data.
///
/// Windows roll over lazily: every method first checks whether `window_duration`
/// has elapsed since `window_start` and, if so, zeroes both counters and starts a
/// new window at the current instant.
#[derive(Debug)]
struct BandwidthTracker {
    /// Bytes recorded via `record_sent` in the current window.
    bytes_sent: u64,
    /// Bytes recorded via `record_received` in the current window.
    bytes_received: u64,
    /// When the current window began.
    window_start: Instant,
    /// Length of each window (one second, fixed in `new`).
    window_duration: Duration,
    /// Maximum `bytes_sent` allowed per window; received bytes are never limited.
    limit: u64,
}

impl BandwidthTracker {
    /// Creates a tracker with empty counters and a one-second window starting now.
    fn new(limit: u64) -> Self {
        Self {
            bytes_sent: 0,
            bytes_received: 0,
            window_start: Instant::now(),
            window_duration: Duration::from_secs(1),
            limit,
        }
    }

    /// Starts a fresh window (zeroing both counters) if the current one has expired.
    fn reset_if_needed(&mut self) {
        let now = Instant::now();
        if now.duration_since(self.window_start) >= self.window_duration {
            self.bytes_sent = 0;
            self.bytes_received = 0;
            self.window_start = now;
        }
    }

    /// Returns whether `bytes` more may be sent in the current window without
    /// exceeding `limit`.
    ///
    /// A sum that would overflow `u64` is necessarily above any limit, so it is
    /// reported as "cannot send" instead of panicking (debug) or wrapping to a
    /// small value that slips past the check (release).
    fn can_send(&mut self, bytes: u64) -> bool {
        self.reset_if_needed();
        matches!(self.bytes_sent.checked_add(bytes), Some(total) if total <= self.limit)
    }

    /// Adds `bytes` to the current window's sent counter, saturating at `u64::MAX`.
    fn record_sent(&mut self, bytes: u64) {
        self.reset_if_needed();
        self.bytes_sent = self.bytes_sent.saturating_add(bytes);
    }

    /// Adds `bytes` to the current window's received counter, saturating at `u64::MAX`.
    fn record_received(&mut self, bytes: u64) {
        self.reset_if_needed();
        self.bytes_received = self.bytes_received.saturating_add(bytes);
    }

    /// Total bytes sent plus received in the current window, saturating at `u64::MAX`.
    fn current_usage(&mut self) -> u64 {
        self.reset_if_needed();
        self.bytes_sent.saturating_add(self.bytes_received)
    }
}
187
188impl RelayConnection {
189 pub fn new(
191 session_id: u32,
192 peer_addr: SocketAddr,
193 config: RelayConnectionConfig,
194 event_sender: mpsc::UnboundedSender<RelayEvent>,
195 action_receiver: mpsc::UnboundedReceiver<RelayAction>,
196 ) -> Self {
197 let now = Instant::now();
198 let state = ConnectionState {
199 is_active: true,
200 outgoing_queue: VecDeque::new(),
201 incoming_queue: VecDeque::new(),
202 buffer_usage: 0,
203 bandwidth_tracker: BandwidthTracker::new(config.bandwidth_limit),
204 last_activity: now,
205 next_keep_alive: now + config.keep_alive_interval,
206 };
207
208 Self {
209 session_id,
210 peer_addr,
211 config,
212 state: Arc::new(Mutex::new(state)),
213 event_sender,
214 action_receiver,
215 }
216 }
217
218 pub fn session_id(&self) -> u32 {
220 self.session_id
221 }
222
223 pub fn peer_addr(&self) -> SocketAddr {
225 self.peer_addr
226 }
227
228 pub fn is_active(&self) -> bool {
230 let state = self.state.lock().unwrap();
231 state.is_active
232 }
233
234 pub fn send_data(&self, data: Vec<u8>) -> RelayResult<()> {
236 if data.len() > self.config.max_frame_size {
237 return Err(RelayError::ProtocolError {
238 frame_type: 0x46, reason: format!("Data size {} exceeds maximum {}", data.len(), self.config.max_frame_size),
240 });
241 }
242
243 let mut state = self.state.lock().unwrap();
244
245 if !state.is_active {
246 return Err(RelayError::SessionError {
247 session_id: Some(self.session_id),
248 kind: crate::relay::error::SessionErrorKind::Terminated,
249 });
250 }
251
252 if !state.bandwidth_tracker.can_send(data.len() as u64) {
254 let current_usage = state.bandwidth_tracker.current_usage();
255 return Err(RelayError::SessionError {
256 session_id: Some(self.session_id),
257 kind: crate::relay::error::SessionErrorKind::BandwidthExceeded {
258 used: current_usage,
259 limit: self.config.bandwidth_limit,
260 },
261 });
262 }
263
264 if state.buffer_usage + data.len() > self.config.buffer_size {
266 return Err(RelayError::ResourceExhausted {
267 resource_type: "buffer".to_string(),
268 current_usage: state.buffer_usage as u64,
269 limit: self.config.buffer_size as u64,
270 });
271 }
272
273 state.bandwidth_tracker.record_sent(data.len() as u64);
275 state.buffer_usage += data.len();
276 state.outgoing_queue.push_back(data.clone());
277 state.last_activity = Instant::now();
278
279 let _ = self.event_sender.send(RelayEvent::DataReceived {
281 session_id: self.session_id,
282 data,
283 });
284
285 Ok(())
286 }
287
288 pub fn receive_data(&self, data: Vec<u8>) -> RelayResult<()> {
290 let mut state = self.state.lock().unwrap();
291
292 if !state.is_active {
293 return Err(RelayError::SessionError {
294 session_id: Some(self.session_id),
295 kind: crate::relay::error::SessionErrorKind::Terminated,
296 });
297 }
298
299 if state.buffer_usage + data.len() > self.config.buffer_size {
301 return Err(RelayError::ResourceExhausted {
302 resource_type: "buffer".to_string(),
303 current_usage: state.buffer_usage as u64,
304 limit: self.config.buffer_size as u64,
305 });
306 }
307
308 state.bandwidth_tracker.record_received(data.len() as u64);
310 state.buffer_usage += data.len();
311 state.incoming_queue.push_back(data.clone());
312 state.last_activity = Instant::now();
313
314 let _ = self.event_sender.send(RelayEvent::DataReceived {
316 session_id: self.session_id,
317 data,
318 });
319
320 Ok(())
321 }
322
323 pub fn next_outgoing(&self) -> Option<Vec<u8>> {
325 let mut state = self.state.lock().unwrap();
326 if let Some(data) = state.outgoing_queue.pop_front() {
327 state.buffer_usage = state.buffer_usage.saturating_sub(data.len());
328 Some(data)
329 } else {
330 None
331 }
332 }
333
334 pub fn next_incoming(&self) -> Option<Vec<u8>> {
336 let mut state = self.state.lock().unwrap();
337 if let Some(data) = state.incoming_queue.pop_front() {
338 state.buffer_usage = state.buffer_usage.saturating_sub(data.len());
339 Some(data)
340 } else {
341 None
342 }
343 }
344
345 pub fn check_timeout(&self) -> RelayResult<()> {
347 let state = self.state.lock().unwrap();
348 let now = Instant::now();
349
350 if now.duration_since(state.last_activity) > self.config.connection_timeout {
351 return Err(RelayError::SessionError {
352 session_id: Some(self.session_id),
353 kind: crate::relay::error::SessionErrorKind::Expired,
354 });
355 }
356
357 Ok(())
358 }
359
360 pub fn should_send_keep_alive(&self) -> bool {
362 let state = self.state.lock().unwrap();
363 Instant::now() >= state.next_keep_alive
364 }
365
366 pub fn send_keep_alive(&self) -> RelayResult<()> {
368 let mut state = self.state.lock().unwrap();
369 state.next_keep_alive = Instant::now() + self.config.keep_alive_interval;
370
371 let _ = self.event_sender.send(RelayEvent::KeepAlive {
372 session_id: self.session_id,
373 });
374
375 Ok(())
376 }
377
378 pub fn terminate(&self, reason: String) -> RelayResult<()> {
380 let mut state = self.state.lock().unwrap();
381 state.is_active = false;
382
383 let _ = self.event_sender.send(RelayEvent::ConnectionTerminated {
384 session_id: self.session_id,
385 reason: reason.clone(),
386 });
387
388 Ok(())
389 }
390
391 pub fn get_stats(&self) -> ConnectionStats {
393 let state = self.state.lock().unwrap();
394 ConnectionStats {
395 session_id: self.session_id,
396 peer_addr: self.peer_addr,
397 is_active: state.is_active,
398 bytes_sent: state.bandwidth_tracker.bytes_sent,
399 bytes_received: state.bandwidth_tracker.bytes_received,
400 buffer_usage: state.buffer_usage,
401 outgoing_queue_size: state.outgoing_queue.len(),
402 incoming_queue_size: state.incoming_queue.len(),
403 last_activity: state.last_activity,
404 }
405 }
406}
407
/// Point-in-time snapshot of a relay connection, produced by `RelayConnection::get_stats`.
#[derive(Debug, Clone)]
pub struct ConnectionStats {
    /// Session identifier of the connection.
    pub session_id: u32,
    /// Address of the remote peer.
    pub peer_addr: SocketAddr,
    /// `false` once the connection has been terminated.
    pub is_active: bool,
    /// Bytes sent in the bandwidth tracker's one-second window as of its last update
    /// (not a lifetime total).
    pub bytes_sent: u64,
    /// Bytes received in the bandwidth tracker's one-second window as of its last update
    /// (not a lifetime total).
    pub bytes_received: u64,
    /// Bytes currently held across the outgoing and incoming queues.
    pub buffer_usage: usize,
    /// Number of payloads waiting in the outgoing queue.
    pub outgoing_queue_size: usize,
    /// Number of payloads waiting in the incoming queue.
    pub incoming_queue_size: usize,
    /// Time of the last accepted send or receive.
    pub last_activity: Instant,
}
421
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::sync::mpsc;

    /// Fixed peer address shared by every test.
    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    /// Builds session 123 at `test_addr()` with `config`.
    ///
    /// Returns the event receiver so tests can inspect emitted events, plus the action
    /// sender so the action channel stays open for the duration of the test.
    fn connect(
        config: RelayConnectionConfig,
    ) -> (
        RelayConnection,
        mpsc::UnboundedReceiver<RelayEvent>,
        mpsc::UnboundedSender<RelayAction>,
    ) {
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        let (action_tx, action_rx) = mpsc::unbounded_channel();
        let connection = RelayConnection::new(123, test_addr(), config, event_tx, action_rx);
        (connection, event_rx, action_tx)
    }

    #[test]
    fn test_relay_connection_creation() {
        let (connection, _events, _actions) = connect(RelayConnectionConfig::default());

        assert_eq!(connection.session_id(), 123);
        assert_eq!(connection.peer_addr(), test_addr());
        assert!(connection.is_active());
    }

    #[test]
    fn test_send_data_within_limits() {
        let (connection, _events, _actions) = connect(RelayConnectionConfig::default());

        let data = vec![1, 2, 3, 4];
        assert!(connection.send_data(data.clone()).is_ok());

        assert_eq!(connection.next_outgoing(), Some(data));
        assert_eq!(connection.next_outgoing(), None);
    }

    #[test]
    fn test_send_data_exceeds_frame_size() {
        let config = RelayConnectionConfig {
            max_frame_size: 10,
            ..Default::default()
        };
        let (connection, _events, _actions) = connect(config);

        // Exactly the maximum frame size is still accepted.
        assert!(connection.send_data(vec![0; 10]).is_ok());
        assert!(connection.send_data(vec![0; 20]).is_err());
    }

    #[test]
    fn test_bandwidth_limiting() {
        let config = RelayConnectionConfig {
            bandwidth_limit: 100,
            ..Default::default()
        };
        let (connection, _events, _actions) = connect(config);

        assert!(connection.send_data(vec![0; 50]).is_ok());
        // 50 + 60 would exceed the 100-byte window limit...
        assert!(connection.send_data(vec![0; 60]).is_err());
        // ...but filling the window exactly is allowed.
        assert!(connection.send_data(vec![0; 50]).is_ok());
    }

    #[test]
    fn test_buffer_size_limiting() {
        let config = RelayConnectionConfig {
            buffer_size: 100,
            ..Default::default()
        };
        let (connection, _events, _actions) = connect(config);

        assert!(connection.send_data(vec![0; 80]).is_ok());
        assert!(connection.send_data(vec![0; 30]).is_err());

        // Draining the queue releases its buffer space.
        assert_eq!(connection.next_outgoing().map(|d| d.len()), Some(80));
        assert!(connection.send_data(vec![0; 30]).is_ok());
    }

    #[test]
    fn test_connection_termination() {
        let (connection, mut events, _actions) = connect(RelayConnectionConfig::default());

        assert!(connection.is_active());

        let reason = "Test termination".to_string();
        assert!(connection.terminate(reason.clone()).is_ok());
        assert!(!connection.is_active());

        match events.try_recv() {
            Ok(RelayEvent::ConnectionTerminated {
                session_id,
                reason: got,
            }) => {
                assert_eq!(session_id, 123);
                assert_eq!(got, reason);
            }
            other => panic!("expected ConnectionTerminated, got {:?}", other),
        }

        assert!(connection.send_data(vec![1, 2, 3]).is_err());
        assert!(connection.receive_data(vec![1, 2, 3]).is_err());
    }

    #[test]
    fn test_keep_alive() {
        // A generous interval keeps the "not yet due" assertions robust on slow machines.
        let config = RelayConnectionConfig {
            keep_alive_interval: Duration::from_millis(50),
            ..Default::default()
        };
        let (connection, mut events, _actions) = connect(config);

        assert!(!connection.should_send_keep_alive());

        std::thread::sleep(Duration::from_millis(60));
        assert!(connection.should_send_keep_alive());

        assert!(connection.send_keep_alive().is_ok());
        assert!(matches!(
            events.try_recv(),
            Ok(RelayEvent::KeepAlive { session_id: 123 })
        ));
        assert!(!connection.should_send_keep_alive());
    }

    #[test]
    fn test_connection_stats() {
        let (connection, _events, _actions) = connect(RelayConnectionConfig::default());

        connection.send_data(vec![1, 2, 3]).unwrap();
        connection.receive_data(vec![4, 5, 6, 7]).unwrap();

        let stats = connection.get_stats();
        assert_eq!(stats.session_id, 123);
        assert_eq!(stats.peer_addr, test_addr());
        assert!(stats.is_active);
        assert_eq!(stats.bytes_sent, 3);
        assert_eq!(stats.bytes_received, 4);
        assert_eq!(stats.buffer_usage, 7);
        assert_eq!(stats.outgoing_queue_size, 1);
        assert_eq!(stats.incoming_queue_size, 1);
    }
}