1use std::collections::VecDeque;
19use std::io::{Read as IoRead, Write as IoWrite};
20use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
21use std::sync::Arc;
22use std::sync::atomic::{AtomicBool, Ordering};
23use std::time::Duration;
24
25use noxu_sync::{Condvar, Mutex};
26
27use crate::error::{RepError, Result};
28
/// Largest payload, in bytes, that one length-prefixed frame may carry
/// (64 MiB).
///
/// The TCP and TLS `receive` implementations compare an incoming frame's
/// length header against this limit before allocating the payload buffer.
pub const MAX_FRAME_PAYLOAD: usize = 64 << 20;
41
42pub trait Channel: Send + Sync {
49 fn send(&self, data: &[u8]) -> Result<()>;
51
52 fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>>;
55
56 fn close(&self) -> Result<()>;
59
60 fn is_open(&self) -> bool;
62}
63
64struct ChannelQueue {
69 queue: Mutex<VecDeque<Vec<u8>>>,
70 condvar: Condvar,
71 writer_closed: AtomicBool,
74}
75
76impl ChannelQueue {
77 fn new() -> Self {
78 Self {
79 queue: Mutex::new(VecDeque::new()),
80 condvar: Condvar::new(),
81 writer_closed: AtomicBool::new(false),
82 }
83 }
84
85 fn push(&self, data: Vec<u8>) {
86 let mut q = self.queue.lock();
87 q.push_back(data);
88 self.condvar.notify_one();
89 }
90
91 fn close_writer(&self) {
93 self.writer_closed.store(true, Ordering::SeqCst);
94 self.condvar.notify_all();
95 }
96
97 fn pop(
101 &self,
102 timeout: Duration,
103 ) -> std::result::Result<Option<Vec<u8>>, ()> {
104 let mut q = self.queue.lock();
105 if q.is_empty() {
106 if self.writer_closed.load(Ordering::SeqCst) {
107 return Err(());
108 }
109 self.condvar.wait_for(&mut q, timeout);
110 }
111 if let Some(data) = q.pop_front() {
112 Ok(Some(data))
113 } else if self.writer_closed.load(Ordering::SeqCst) {
114 Err(())
115 } else {
116 Ok(None)
117 }
118 }
119}
120
121pub struct LocalChannel {
127 send_queue: Arc<ChannelQueue>,
129 recv_queue: Arc<ChannelQueue>,
131 open: AtomicBool,
133}
134
135impl LocalChannel {
136 fn new(
137 send_queue: Arc<ChannelQueue>,
138 recv_queue: Arc<ChannelQueue>,
139 ) -> Self {
140 Self { send_queue, recv_queue, open: AtomicBool::new(true) }
141 }
142}
143
144impl Channel for LocalChannel {
145 fn send(&self, data: &[u8]) -> Result<()> {
146 if !self.is_open() {
147 return Err(RepError::ChannelClosed("channel is closed".into()));
148 }
149 self.send_queue.push(data.to_vec());
150 Ok(())
151 }
152
153 fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
154 if !self.is_open() {
155 return Err(RepError::ChannelClosed("channel is closed".into()));
156 }
157 self.recv_queue.pop(timeout).map_err(|()| {
158 RepError::ChannelClosed("peer closed the channel".into())
159 })
160 }
161
162 fn close(&self) -> Result<()> {
163 self.open.store(false, Ordering::SeqCst);
164 self.send_queue.close_writer();
167 self.recv_queue.condvar.notify_all();
169 Ok(())
170 }
171
172 fn is_open(&self) -> bool {
173 self.open.load(Ordering::SeqCst)
174 }
175}
176
177pub struct LocalChannelPair {
182 pub channel_a: LocalChannel,
183 pub channel_b: LocalChannel,
184}
185
186impl LocalChannelPair {
187 pub fn new() -> Self {
189 let queue_a_to_b = Arc::new(ChannelQueue::new());
190 let queue_b_to_a = Arc::new(ChannelQueue::new());
191
192 let channel_a = LocalChannel::new(
193 Arc::clone(&queue_a_to_b),
194 Arc::clone(&queue_b_to_a),
195 );
196 let channel_b = LocalChannel::new(
197 Arc::clone(&queue_b_to_a),
198 Arc::clone(&queue_a_to_b),
199 );
200
201 Self { channel_a, channel_b }
202 }
203}
204
205impl Default for LocalChannelPair {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211pub struct TcpChannel {
223 stream: Arc<Mutex<TcpStream>>,
227 open: AtomicBool,
229}
230
231impl TcpChannel {
232 pub fn new(stream: TcpStream) -> Self {
237 Self {
238 stream: Arc::new(Mutex::new(stream)),
239 open: AtomicBool::new(true),
240 }
241 }
242
243 pub fn connect(addr: SocketAddr) -> Result<Self> {
248 let stream = TcpStream::connect_timeout(&addr, Duration::from_secs(30))
249 .map_err(|e| RepError::NetworkError(e.to_string()))?;
250 Ok(Self::new(stream))
251 }
252
253 pub fn connect_host(host: &str, port: u16) -> Result<Self> {
262 let addrs: Vec<SocketAddr> = (host, port)
263 .to_socket_addrs()
264 .map_err(|e| {
265 RepError::NetworkError(format!(
266 "DNS resolution failed for {host}:{port}: {e}"
267 ))
268 })?
269 .collect();
270
271 if addrs.is_empty() {
272 return Err(RepError::NetworkError(format!(
273 "no addresses resolved for {host}:{port}"
274 )));
275 }
276
277 let mut sorted = addrs;
279 sorted.sort_by_key(|a| if a.is_ipv6() { 0u8 } else { 1u8 });
280
281 let mut last_err = None;
282 for addr in &sorted {
283 match TcpStream::connect_timeout(addr, Duration::from_secs(30)) {
284 Ok(stream) => return Ok(Self::new(stream)),
285 Err(e) => last_err = Some(e),
286 }
287 }
288
289 Err(RepError::NetworkError(format!(
290 "could not connect to {host}:{port}: {}",
291 last_err.unwrap()
292 )))
293 }
294
295 pub fn bind_dual_stack(port: u16) -> Result<TcpChannelListener> {
302 if let Ok(listener) = TcpListener::bind(format!("[::]:{}", port)) {
304 return Ok(TcpChannelListener { listener });
305 }
306 let addr: SocketAddr =
308 format!("0.0.0.0:{port}").parse().map_err(|e| {
309 RepError::NetworkError(format!("invalid bind addr: {e}"))
310 })?;
311 TcpChannelListener::bind(addr)
312 }
313}
314
315impl Channel for TcpChannel {
316 fn send(&self, data: &[u8]) -> Result<()> {
321 if !self.is_open() {
322 return Err(RepError::ChannelClosed("TcpChannel is closed".into()));
323 }
324 let len = data.len() as u32;
325 let mut stream = self.stream.lock();
326 stream.set_write_timeout(Some(Duration::from_secs(30))).ok();
328 stream
329 .write_all(&len.to_le_bytes())
330 .map_err(|e| RepError::NetworkError(e.to_string()))?;
331 stream
332 .write_all(data)
333 .map_err(|e| RepError::NetworkError(e.to_string()))?;
334 stream.flush().map_err(|e| RepError::NetworkError(e.to_string()))?;
335 Ok(())
336 }
337
338 fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
344 if !self.is_open() {
345 return Err(RepError::ChannelClosed("TcpChannel is closed".into()));
346 }
347
348 let mut stream = self.stream.lock();
349
350 stream
352 .set_read_timeout(Some(timeout))
353 .map_err(|e| RepError::NetworkError(e.to_string()))?;
354
355 let mut len_buf = [0u8; 4];
357 match stream.read_exact(&mut len_buf) {
358 Ok(()) => {}
359 Err(e) => {
360 if e.kind() == std::io::ErrorKind::WouldBlock
362 || e.kind() == std::io::ErrorKind::TimedOut
363 {
364 return Ok(None);
365 }
366 if e.kind() == std::io::ErrorKind::UnexpectedEof {
368 return Err(RepError::ChannelClosed(
369 "connection closed by peer".into(),
370 ));
371 }
372 return Err(RepError::NetworkError(e.to_string()));
373 }
374 }
375
376 let payload_len = u32::from_le_bytes(len_buf) as usize;
377 if payload_len > MAX_FRAME_PAYLOAD {
378 return Err(RepError::ProtocolError(format!(
379 "frame payload too large: {} > {}",
380 payload_len, MAX_FRAME_PAYLOAD
381 )));
382 }
383
384 let payload_timeout = timeout.max(Duration::from_secs(30));
390 stream.set_read_timeout(Some(payload_timeout)).ok();
391
392 let mut payload = vec![0u8; payload_len];
393 stream
394 .read_exact(&mut payload)
395 .map_err(|e| RepError::NetworkError(e.to_string()))?;
396
397 Ok(Some(payload))
398 }
399
400 fn close(&self) -> Result<()> {
402 self.open.store(false, Ordering::SeqCst);
403 let stream = self.stream.lock();
404 stream
405 .shutdown(std::net::Shutdown::Both)
406 .map_err(|e| RepError::NetworkError(e.to_string()))
407 }
408
409 fn is_open(&self) -> bool {
410 self.open.load(Ordering::SeqCst)
411 }
412}
413
414pub struct TcpChannelListener {
424 listener: TcpListener,
425}
426
427impl TcpChannelListener {
428 pub fn bind(addr: SocketAddr) -> Result<Self> {
430 let listener = TcpListener::bind(addr)
431 .map_err(|e| RepError::NetworkError(e.to_string()))?;
432 Ok(Self { listener })
433 }
434
435 pub fn local_addr(&self) -> Result<SocketAddr> {
437 self.listener
438 .local_addr()
439 .map_err(|e| RepError::NetworkError(e.to_string()))
440 }
441
442 pub fn accept(&self) -> Result<TcpChannel> {
446 let (stream, _peer) = self
447 .listener
448 .accept()
449 .map_err(|e| RepError::NetworkError(e.to_string()))?;
450 Ok(TcpChannel::new(stream))
451 }
452
453 pub fn set_accept_timeout(&self, timeout: Option<Duration>) -> Result<()> {
458 #[cfg(unix)]
459 {
460 use std::os::fd::AsRawFd;
461 let fd = self.listener.as_raw_fd();
462 let tv = match timeout {
463 Some(d) => libc::timeval {
464 tv_sec: d.as_secs() as libc::time_t,
465 tv_usec: d.subsec_micros() as libc::suseconds_t,
466 },
467 None => libc::timeval { tv_sec: 0, tv_usec: 0 },
468 };
469 let rc = unsafe {
470 libc::setsockopt(
471 fd,
472 libc::SOL_SOCKET,
473 libc::SO_RCVTIMEO,
474 &tv as *const _ as *const libc::c_void,
475 std::mem::size_of::<libc::timeval>() as libc::socklen_t,
476 )
477 };
478 if rc != 0 {
479 return Err(RepError::NetworkError(
480 std::io::Error::last_os_error().to_string(),
481 ));
482 }
483 }
484 #[cfg(not(unix))]
485 {
486 let _ = timeout;
487 }
488 Ok(())
489 }
490}
491
492#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
497use crate::tls::TlsConfig;
498
/// Object-safe view over the TLS stream types that `TlsTcpChannel` stores, so
/// client and server streams from either backend share one code path.
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
trait TlsStreamOps: Send + 'static {
    /// Reads exactly `bytes.len()` plaintext bytes (`Read::read_exact`).
    fn read_exact_buf(&mut self, bytes: &mut [u8]) -> std::io::Result<()>;

    /// Writes all of `bytes` as plaintext (`Write::write_all`).
    fn write_all_buf(&mut self, bytes: &[u8]) -> std::io::Result<()>;

    /// Flushes buffered output (`Write::flush`).
    fn flush_buf(&mut self) -> std::io::Result<()>;

    /// Sets the read timeout of the underlying TCP socket.
    fn set_read_timeout_inner(
        &mut self,
        timeout: Option<Duration>,
    ) -> std::io::Result<()>;

    /// Sets the write timeout of the underlying TCP socket.
    fn set_write_timeout_inner(
        &mut self,
        timeout: Option<Duration>,
    ) -> std::io::Result<()>;

    /// Shuts down both directions of the underlying TCP socket.
    fn shutdown_inner(&self) -> std::io::Result<()>;
}
518
// Server-side rustls stream: plaintext I/O goes through the TLS session, and
// socket options are applied to the wrapped `TcpStream` (`self.sock`).
#[cfg(feature = "tls-rustls")]
impl TlsStreamOps for rustls::StreamOwned<rustls::ServerConnection, TcpStream> {
    fn read_exact_buf(&mut self, bytes: &mut [u8]) -> std::io::Result<()> {
        self.read_exact(bytes)
    }

    fn write_all_buf(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        self.write_all(bytes)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        self.flush()
    }

    fn set_read_timeout_inner(
        &mut self,
        timeout: Option<Duration>,
    ) -> std::io::Result<()> {
        let socket = &self.sock;
        socket.set_read_timeout(timeout)
    }

    fn set_write_timeout_inner(
        &mut self,
        timeout: Option<Duration>,
    ) -> std::io::Result<()> {
        let socket = &self.sock;
        socket.set_write_timeout(timeout)
    }

    fn shutdown_inner(&self) -> std::io::Result<()> {
        let socket = &self.sock;
        socket.shutdown(std::net::Shutdown::Both)
    }
}
548
// Client-side rustls stream: plaintext I/O goes through the TLS session, and
// socket options are applied to the wrapped `TcpStream` (`self.sock`).
#[cfg(feature = "tls-rustls")]
impl TlsStreamOps for rustls::StreamOwned<rustls::ClientConnection, TcpStream> {
    fn read_exact_buf(&mut self, bytes: &mut [u8]) -> std::io::Result<()> {
        self.read_exact(bytes)
    }

    fn write_all_buf(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        self.write_all(bytes)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        self.flush()
    }

    fn set_read_timeout_inner(
        &mut self,
        timeout: Option<Duration>,
    ) -> std::io::Result<()> {
        let socket = &self.sock;
        socket.set_read_timeout(timeout)
    }

    fn set_write_timeout_inner(
        &mut self,
        timeout: Option<Duration>,
    ) -> std::io::Result<()> {
        let socket = &self.sock;
        socket.set_write_timeout(timeout)
    }

    fn shutdown_inner(&self) -> std::io::Result<()> {
        let socket = &self.sock;
        socket.shutdown(std::net::Shutdown::Both)
    }
}
576
// native-tls stream (client or server side): plaintext I/O goes through the
// TLS session, and socket options reach the `TcpStream` through `get_ref`.
#[cfg(feature = "tls-native")]
impl TlsStreamOps for native_tls::TlsStream<TcpStream> {
    fn read_exact_buf(&mut self, bytes: &mut [u8]) -> std::io::Result<()> {
        self.read_exact(bytes)
    }

    fn write_all_buf(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        self.write_all(bytes)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        self.flush()
    }

    fn set_read_timeout_inner(
        &mut self,
        timeout: Option<Duration>,
    ) -> std::io::Result<()> {
        let socket: &TcpStream = self.get_ref();
        socket.set_read_timeout(timeout)
    }

    fn set_write_timeout_inner(
        &mut self,
        timeout: Option<Duration>,
    ) -> std::io::Result<()> {
        let socket: &TcpStream = self.get_ref();
        socket.set_write_timeout(timeout)
    }

    fn shutdown_inner(&self) -> std::io::Result<()> {
        let socket: &TcpStream = self.get_ref();
        socket.shutdown(std::net::Shutdown::Both)
    }
}
606
/// A [`Channel`] over a TLS-wrapped TCP connection.
///
/// It uses the same framing as `TcpChannel`: a 4-byte little-endian length
/// followed by the payload. The TLS stream sits behind a `std::sync::Mutex`,
/// so operations on one channel are serialized.
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
pub struct TlsTcpChannel {
    stream: Arc<std::sync::Mutex<Box<dyn TlsStreamOps>>>,
    open: AtomicBool,
}

#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl TlsTcpChannel {
    /// Bound on connection establishment: the TCP connect and, for
    /// native-tls, the blocking client handshake.
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Wraps an established TLS stream; the channel starts open.
    fn wrap(stream: Box<dyn TlsStreamOps>) -> Self {
        Self {
            stream: Arc::new(std::sync::Mutex::new(stream)),
            open: AtomicBool::new(true),
        }
    }

    /// Connects to `addr` and sets up TLS using `tls`.
    ///
    /// When both TLS features are enabled, the rustls backend is used.
    pub fn connect_with_tls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        #[cfg(feature = "tls-rustls")]
        {
            return Self::connect_rustls(addr, tls);
        }
        #[cfg(all(feature = "tls-native", not(feature = "tls-rustls")))]
        {
            return Self::connect_native(addr, tls);
        }
        #[allow(unreachable_code)]
        Err(RepError::NetworkError("no TLS feature enabled".into()))
    }

    /// rustls client.
    ///
    /// No handshake I/O happens here: `StreamOwned::new` only pairs the
    /// session with the socket, and the handshake is driven by the first
    /// `send`/`receive`, which set their own socket timeouts.
    #[cfg(feature = "tls-rustls")]
    fn connect_rustls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        use rustls::pki_types::ServerName;
        let cfg = tls.to_rustls_client_config()?;
        let server_name = ServerName::try_from(tls.server_name.clone())
            .map_err(|e| {
                RepError::NetworkError(format!("invalid server name: {e}"))
            })?;
        let tcp = TcpStream::connect_timeout(&addr, Self::CONNECT_TIMEOUT)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        let conn =
            rustls::ClientConnection::new(cfg, server_name).map_err(|e| {
                RepError::NetworkError(format!("TLS client init: {e}"))
            })?;
        let stream = rustls::StreamOwned::new(conn, tcp);
        Ok(Self::wrap(Box::new(stream)))
    }

    /// native-tls client; the full handshake runs inside this call.
    #[cfg(feature = "tls-native")]
    fn connect_native(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        let connector = tls.to_native_connector()?;
        let tcp = TcpStream::connect_timeout(&addr, Self::CONNECT_TIMEOUT)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        // `connector.connect` handshakes on a blocking socket. Without
        // timeouts an unresponsive server would hang this call forever.
        // `send`/`receive` replace these timeouts on every call.
        tcp.set_read_timeout(Some(Self::CONNECT_TIMEOUT))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        tcp.set_write_timeout(Some(Self::CONNECT_TIMEOUT))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        let stream = connector.connect(&tls.server_name, tcp).map_err(|e| {
            RepError::NetworkError(format!("TLS handshake: {e}"))
        })?;
        Ok(Self::wrap(Box::new(stream)))
    }
}
693
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl Channel for TlsTcpChannel {
    /// Sends `data` as one frame: a little-endian `u32` length, then the bytes.
    ///
    /// Header and payload are handed to the TLS stream in one `write_all`.
    /// Separate small writes would become separate small records and
    /// segments, which stall under Nagle's algorithm and delayed ACKs.
    ///
    /// # Errors
    /// * `ChannelClosed` if this channel has been closed.
    /// * `ProtocolError` if `data` is too long for a `u32` length prefix.
    /// * `NetworkError` on lock poisoning or any I/O failure, including the
    ///   30 s write timeout.
    fn send(&self, data: &[u8]) -> Result<()> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed(
                "TlsTcpChannel is closed".into(),
            ));
        }
        let len = u32::try_from(data.len()).map_err(|_| {
            RepError::ProtocolError(format!(
                "frame payload too large to encode: {} bytes",
                data.len()
            ))
        })?;
        let mut frame = Vec::with_capacity(4 + data.len());
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(data);

        let mut s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        s.set_write_timeout_inner(Some(Duration::from_secs(30)))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        s.write_all_buf(&frame)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        s.flush_buf().map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(())
    }

    /// Waits up to `timeout` for the start of the next frame, then reads it.
    ///
    /// Returns `Ok(None)` if no frame starts within `timeout`. A zero timeout
    /// is treated as a 1 ms poll. Once a frame has begun, the rest of it gets
    /// `max(timeout, 30 s)`.
    ///
    /// # Errors
    /// * `ChannelClosed` if this channel is closed or the peer closes the
    ///   connection while the length header is being read.
    /// * `ProtocolError` if the announced length exceeds `MAX_FRAME_PAYLOAD`.
    /// * `NetworkError` on lock poisoning or any other I/O failure.
    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
        // Maps a failure while reading the length header: EOF means the peer
        // went away; anything else is a network error.
        fn header_error(e: std::io::Error) -> RepError {
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                RepError::ChannelClosed("connection closed by peer".into())
            } else {
                RepError::NetworkError(e.to_string())
            }
        }

        if !self.is_open() {
            return Err(RepError::ChannelClosed(
                "TlsTcpChannel is closed".into(),
            ));
        }

        // Socket read timeouts reject `Duration::ZERO`; use a minimal wait
        // instead so a zero timeout acts as a poll rather than an error.
        let first_byte_timeout = if timeout.is_zero() {
            Duration::from_millis(1)
        } else {
            timeout
        };
        let frame_timeout = timeout.max(Duration::from_secs(30));

        let mut s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        s.set_read_timeout_inner(Some(first_byte_timeout))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;

        // Only the first header byte is read under the caller's timeout. A
        // one-byte `read_exact` consumes either that byte or nothing, so a
        // timeout leaves the plaintext stream on a frame boundary. Reading
        // all four bytes here could consume part of a header and then report
        // a timeout, desynchronising every later frame.
        let mut len_buf = [0u8; 4];
        match s.read_exact_buf(&mut len_buf[..1]) {
            Ok(()) => {}
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::WouldBlock
                        | std::io::ErrorKind::TimedOut
                ) =>
            {
                return Ok(None);
            }
            Err(e) => return Err(header_error(e)),
        }

        // The frame has started: give the remainder the longer timeout.
        s.set_read_timeout_inner(Some(frame_timeout))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        s.read_exact_buf(&mut len_buf[1..]).map_err(header_error)?;

        let payload_len = u32::from_le_bytes(len_buf) as usize;
        if payload_len > MAX_FRAME_PAYLOAD {
            return Err(RepError::ProtocolError(format!(
                "frame payload too large: {} > {}",
                payload_len, MAX_FRAME_PAYLOAD
            )));
        }

        let mut payload = vec![0u8; payload_len];
        s.read_exact_buf(&mut payload)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(Some(payload))
    }

    /// Marks the channel closed and shuts down the underlying TCP socket.
    ///
    /// NOTE(review): `shutdown_inner` closes the socket directly; no TLS
    /// close_notify is sent from here.
    fn close(&self) -> Result<()> {
        self.open.store(false, Ordering::SeqCst);
        let s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        s.shutdown_inner().map_err(|e| RepError::NetworkError(e.to_string()))
    }

    /// `false` once `close` has been called.
    fn is_open(&self) -> bool {
        self.open.load(Ordering::SeqCst)
    }
}
772
/// Server-side TLS state for whichever backend the listener was built with.
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
enum TlsAcceptorImpl {
    /// Shared rustls server configuration; each accepted connection gets its
    /// own `ServerConnection` built from it.
    #[cfg(feature = "tls-rustls")]
    Rustls(Arc<rustls::ServerConfig>),
    /// native-tls acceptor that performs the server handshake.
    #[cfg(feature = "tls-native")]
    Native(native_tls::TlsAcceptor),
}
782
/// A TCP listener that wraps every accepted connection in TLS.
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
pub struct TlsTcpChannelListener {
    listener: TcpListener,
    acceptor: TlsAcceptorImpl,
}

#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl TlsTcpChannelListener {
    /// Binds to `addr` and prepares the server-side TLS configuration.
    ///
    /// When both TLS features are enabled, the rustls backend is used.
    ///
    /// # Errors
    /// `NetworkError` if binding fails. Errors from building the TLS
    /// configuration are propagated unchanged.
    pub fn bind_with_tls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        let listener = TcpListener::bind(addr)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        #[cfg(feature = "tls-rustls")]
        let acceptor = {
            let cfg = tls.to_rustls_server_config()?;
            TlsAcceptorImpl::Rustls(cfg)
        };
        #[cfg(all(feature = "tls-native", not(feature = "tls-rustls")))]
        let acceptor = {
            let a = tls.to_native_acceptor()?;
            TlsAcceptorImpl::Native(a)
        };
        Ok(Self { listener, acceptor })
    }

    /// Returns the bound address. This is useful after binding to port 0.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.listener
            .local_addr()
            .map_err(|e| RepError::NetworkError(e.to_string()))
    }

    /// Accepts the next TCP connection and wraps it in TLS.
    ///
    /// # Errors
    /// `NetworkError` if `accept`, TLS session setup, or (native-tls) the
    /// handshake fails.
    pub fn accept(&self) -> Result<TlsTcpChannel> {
        let (tcp, _peer) = self
            .listener
            .accept()
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        match &self.acceptor {
            #[cfg(feature = "tls-rustls")]
            TlsAcceptorImpl::Rustls(cfg) => {
                // No handshake I/O here: the rustls handshake is driven by
                // the channel's first send/receive.
                let conn = rustls::ServerConnection::new(Arc::clone(cfg))
                    .map_err(|e| {
                        RepError::NetworkError(format!("TLS server init: {e}"))
                    })?;
                let stream = rustls::StreamOwned::new(conn, tcp);
                Ok(TlsTcpChannel::wrap(Box::new(stream)))
            }
            #[cfg(feature = "tls-native")]
            TlsAcceptorImpl::Native(acceptor) => {
                // native-tls runs the whole server handshake inside
                // `accept`. Bound it so a client that connects and then
                // stalls cannot block this call forever. The channel's
                // send/receive set their own timeouts afterwards.
                let handshake_timeout = Some(Duration::from_secs(30));
                tcp.set_read_timeout(handshake_timeout)
                    .map_err(|e| RepError::NetworkError(e.to_string()))?;
                tcp.set_write_timeout(handshake_timeout)
                    .map_err(|e| RepError::NetworkError(e.to_string()))?;
                let stream = acceptor.accept(tcp).map_err(|e| {
                    RepError::NetworkError(format!("TLS handshake: {e}"))
                })?;
                Ok(TlsTcpChannel::wrap(Box::new(stream)))
            }
        }
    }
}
849
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_send_receive_basic() {
        let pair = LocalChannelPair::new();
        let msg = b"hello world";
        pair.channel_a.send(msg).unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(msg.to_vec()));
    }

    #[test]
    fn test_bidirectional() {
        let pair = LocalChannelPair::new();

        pair.channel_a.send(b"from a").unwrap();
        pair.channel_b.send(b"from b").unwrap();

        let recv_b = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(recv_b, Some(b"from a".to_vec()));

        let recv_a = pair.channel_a.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(recv_a, Some(b"from b".to_vec()));
    }

    #[test]
    fn test_multiple_messages_fifo() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"first").unwrap();
        pair.channel_a.send(b"second").unwrap();
        pair.channel_a.send(b"third").unwrap();

        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"first".to_vec())
        );
        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"second".to_vec())
        );
        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"third".to_vec())
        );
    }

    #[test]
    fn test_receive_timeout_empty_queue() {
        let pair = LocalChannelPair::new();
        let result = pair.channel_b.receive(Duration::from_millis(50)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_send_after_close_fails() {
        let pair = LocalChannelPair::new();
        pair.channel_a.close().unwrap();
        let result = pair.channel_a.send(b"should fail");
        assert!(result.is_err());
    }

    #[test]
    fn test_receive_after_close_fails() {
        let pair = LocalChannelPair::new();
        pair.channel_b.close().unwrap();
        let result = pair.channel_b.receive(Duration::from_millis(10));
        assert!(result.is_err());
    }

    /// Messages sent before the peer closed are still delivered; only after
    /// they are drained does `receive` report the closure.
    #[test]
    fn test_peer_close_drains_then_errors() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"last").unwrap();
        pair.channel_a.close().unwrap();
        assert_eq!(
            pair.channel_b.receive(Duration::from_millis(10)).unwrap(),
            Some(b"last".to_vec())
        );
        assert!(pair.channel_b.receive(Duration::from_millis(10)).is_err());
    }

    #[test]
    fn test_is_open() {
        let pair = LocalChannelPair::new();
        assert!(pair.channel_a.is_open());
        assert!(pair.channel_b.is_open());

        pair.channel_a.close().unwrap();
        assert!(!pair.channel_a.is_open());
        assert!(pair.channel_b.is_open());
    }

    #[test]
    fn test_empty_message() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"").unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(vec![]));
    }

    #[test]
    fn test_large_message() {
        let pair = LocalChannelPair::new();
        // 1 MiB payload.
        let large = vec![0xABu8; 1024 * 1024];
        pair.channel_a.send(&large).unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(large));
    }

    #[test]
    fn test_concurrent_send_receive() {
        let pair = LocalChannelPair::new();
        std::thread::scope(|s| {
            let a = &pair.channel_a;
            let b = &pair.channel_b;

            let handle = s.spawn(|| {
                let msg = b.receive(Duration::from_secs(5)).unwrap();
                assert_eq!(msg, Some(b"concurrent".to_vec()));
                b.send(b"ack").unwrap();
            });

            a.send(b"concurrent").unwrap();
            let ack = a.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(ack, Some(b"ack".to_vec()));
            handle.join().unwrap();
        });
    }

    #[test]
    fn test_default_trait() {
        let pair = LocalChannelPair::default();
        assert!(pair.channel_a.is_open());
        assert!(pair.channel_b.is_open());
    }

    #[test]
    fn test_tcp_channel_send_receive() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            let msg = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(msg, Some(b"hello tcp".to_vec()));
            ch.send(b"world").unwrap();
        });

        let client = TcpChannel::connect(addr).unwrap();
        client.send(b"hello tcp").unwrap();
        let reply = client.receive(Duration::from_secs(5)).unwrap();
        assert_eq!(reply, Some(b"world".to_vec()));

        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_multiple_messages() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            for i in 0u8..5 {
                let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                assert_eq!(msg, vec![i]);
            }
        });

        let client = TcpChannel::connect(addr).unwrap();
        for i in 0u8..5 {
            client.send(&[i]).unwrap();
        }
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_receive_timeout() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            // Accept but never send, so the client's receive must time out.
            let (_stream, _) = listener.accept().unwrap();
            std::thread::sleep(Duration::from_secs(2));
        });

        let client = TcpChannel::connect(addr).unwrap();
        let result = client.receive(Duration::from_millis(100)).unwrap();
        assert_eq!(result, None, "expected timeout → None");

        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_is_open_and_close() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            let (_stream, _) = listener.accept().unwrap();
            std::thread::sleep(Duration::from_millis(200));
        });

        let client = TcpChannel::connect(addr).unwrap();
        assert!(client.is_open());
        client.close().unwrap();
        assert!(!client.is_open());

        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_large_payload() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let payload: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();
        let expected = payload.clone();

        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
            assert_eq!(msg, expected);
        });

        let client = TcpChannel::connect(addr).unwrap();
        client.send(&payload).unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_listener_bind_and_accept() {
        let listener =
            TcpChannelListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            let msg = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(msg, Some(b"ping".to_vec()));
        });

        let client = TcpChannel::connect(addr).unwrap();
        client.send(b"ping").unwrap();
        handle.join().unwrap();
    }

    /// A raw length header above `MAX_FRAME_PAYLOAD` must be rejected with a
    /// protocol error before any payload is allocated or read.
    #[test]
    fn test_tcp_channel_rejects_oversize_frame() {
        use std::io::Write;
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            // Write only the header; the client must reject it immediately.
            let (mut stream, _) = listener.accept().unwrap();
            let oversized = (crate::net::channel::MAX_FRAME_PAYLOAD as u32)
                .saturating_add(1);
            stream.write_all(&oversized.to_le_bytes()).unwrap();
            std::thread::sleep(Duration::from_millis(200));
        });

        let client = TcpChannel::connect(addr).unwrap();
        let err = client
            .receive(Duration::from_secs(5))
            .expect_err("oversize frame must be rejected");
        match err {
            RepError::ProtocolError(msg) => {
                assert!(
                    msg.contains("frame payload too large"),
                    "unexpected protocol-error message: {}",
                    msg
                );
            }
            other => panic!("expected ProtocolError, got {:?}", other),
        }

        handle.join().unwrap();
    }

    #[cfg(feature = "tls-rustls")]
    mod tls_tests {
        use super::*;
        use crate::tls::TlsConfig;

        #[test]
        fn test_tls_tcp_send_receive() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                let msg = ch.receive(Duration::from_secs(5)).unwrap();
                assert_eq!(msg, Some(b"hello tls".to_vec()));
                ch.send(b"world tls").unwrap();
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            client.send(b"hello tls").unwrap();
            let reply = client.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(reply, Some(b"world tls".to_vec()));

            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_multiple_messages() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                for i in 0u8..4 {
                    let msg =
                        ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                    assert_eq!(msg, vec![i]);
                }
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            for i in 0u8..4 {
                client.send(&[i]).unwrap();
            }
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_large_payload() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();
            let payload: Vec<u8> =
                (0..65536).map(|i| (i % 256) as u8).collect();
            let expected = payload.clone();

            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                assert_eq!(msg, expected);
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            client.send(&payload).unwrap();
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_receive_timeout() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                // Accept but never drive the session, so the client times out.
                let _ch = listener.accept().unwrap();
                std::thread::sleep(Duration::from_secs(2));
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            let result = client.receive(Duration::from_millis(500)).unwrap();
            assert_eq!(result, None, "expected timeout → None");

            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_close() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                let _ch = listener.accept().unwrap();
                std::thread::sleep(Duration::from_millis(200));
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            assert!(client.is_open());
            client.close().unwrap();
            assert!(!client.is_open());

            handle.join().unwrap();
        }

        /// A frame whose length header exceeds `MAX_FRAME_PAYLOAD` must be
        /// rejected by the TLS receiver with a protocol error.
        #[test]
        fn test_tls_tcp_rejects_oversize_frame() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                // The send may fail once the client closes; that's fine.
                let ch = listener.accept().unwrap();
                let oversized =
                    vec![0u8; crate::net::channel::MAX_FRAME_PAYLOAD + 1];
                let _ = ch.send(&oversized);
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            let result = client.receive(Duration::from_secs(10));
            let _ = client.close();
            let err = result.expect_err("oversize TLS frame must be rejected");
            match err {
                RepError::ProtocolError(msg) => {
                    assert!(
                        msg.contains("frame payload too large"),
                        "unexpected protocol-error message: {}",
                        msg
                    );
                }
                other => panic!("expected ProtocolError, got {:?}", other),
            }
            let _ = handle.join();
        }
    }
}