1use std::collections::VecDeque;
19use std::io::{Read as IoRead, Write as IoWrite};
20use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
21use std::sync::Arc;
22use std::sync::atomic::{AtomicBool, Ordering};
23use std::time::Duration;
24
25use noxu_sync::{Condvar, Mutex};
26
27use crate::error::{RepError, Result};
28
/// Largest payload, in bytes, that a receiver accepts in one frame (64 MiB).
///
/// A length prefix above this value is rejected with
/// `RepError::ProtocolError` before any payload buffer is allocated.
pub const MAX_FRAME_PAYLOAD: usize = 64 << 20;
41
/// A bidirectional, message-oriented transport.
///
/// Messages are delivered whole, as byte vectors. The trait requires
/// `Send + Sync`, so one channel value may be shared between threads.
pub trait Channel: Send + Sync {
    /// Sends one message to the peer.
    ///
    /// # Errors
    /// The implementations in this module return `RepError::ChannelClosed`
    /// once `close` has been called on this end.
    fn send(&self, data: &[u8]) -> Result<()>;

    /// Waits for the next message from the peer.
    ///
    /// Returns `Ok(Some(msg))` when a message is available and `Ok(None)`
    /// when none arrived during the wait. Some implementations may return
    /// `Ok(None)` before `timeout` has fully elapsed, so callers should
    /// treat `Ok(None)` as "nothing yet" and retry as needed.
    ///
    /// # Errors
    /// The implementations in this module return `RepError::ChannelClosed`
    /// when this end is closed or the peer has closed the connection.
    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>>;

    /// Closes this end of the channel; afterwards `is_open` returns `false`.
    fn close(&self) -> Result<()>;

    /// Returns `true` until `close` has been called on this end.
    fn is_open(&self) -> bool;
}
63
64struct ChannelQueue {
69 queue: Mutex<VecDeque<Vec<u8>>>,
70 condvar: Condvar,
71 writer_closed: AtomicBool,
74}
75
76impl ChannelQueue {
77 fn new() -> Self {
78 Self {
79 queue: Mutex::new(VecDeque::new()),
80 condvar: Condvar::new(),
81 writer_closed: AtomicBool::new(false),
82 }
83 }
84
85 fn push(&self, data: Vec<u8>) {
86 let mut q = self.queue.lock();
87 q.push_back(data);
88 self.condvar.notify_one();
89 }
90
91 fn close_writer(&self) {
93 self.writer_closed.store(true, Ordering::SeqCst);
94 self.condvar.notify_all();
95 }
96
97 fn pop(
101 &self,
102 timeout: Duration,
103 ) -> std::result::Result<Option<Vec<u8>>, ()> {
104 let mut q = self.queue.lock();
105 if q.is_empty() {
106 if self.writer_closed.load(Ordering::SeqCst) {
107 return Err(());
108 }
109 self.condvar.wait_for(&mut q, timeout);
110 }
111 if let Some(data) = q.pop_front() {
112 Ok(Some(data))
113 } else if self.writer_closed.load(Ordering::SeqCst) {
114 Err(())
115 } else {
116 Ok(None)
117 }
118 }
119}
120
121pub struct LocalChannel {
127 send_queue: Arc<ChannelQueue>,
129 recv_queue: Arc<ChannelQueue>,
131 open: AtomicBool,
133}
134
135impl LocalChannel {
136 fn new(
137 send_queue: Arc<ChannelQueue>,
138 recv_queue: Arc<ChannelQueue>,
139 ) -> Self {
140 Self { send_queue, recv_queue, open: AtomicBool::new(true) }
141 }
142}
143
144impl Channel for LocalChannel {
145 fn send(&self, data: &[u8]) -> Result<()> {
146 if !self.is_open() {
147 return Err(RepError::ChannelClosed("channel is closed".into()));
148 }
149 self.send_queue.push(data.to_vec());
150 Ok(())
151 }
152
153 fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
154 if !self.is_open() {
155 return Err(RepError::ChannelClosed("channel is closed".into()));
156 }
157 self.recv_queue.pop(timeout).map_err(|()| {
158 RepError::ChannelClosed("peer closed the channel".into())
159 })
160 }
161
162 fn close(&self) -> Result<()> {
163 self.open.store(false, Ordering::SeqCst);
164 self.send_queue.close_writer();
167 self.recv_queue.condvar.notify_all();
169 Ok(())
170 }
171
172 fn is_open(&self) -> bool {
173 self.open.load(Ordering::SeqCst)
174 }
175}
176
177pub struct LocalChannelPair {
182 pub channel_a: LocalChannel,
183 pub channel_b: LocalChannel,
184}
185
186impl LocalChannelPair {
187 pub fn new() -> Self {
189 let queue_a_to_b = Arc::new(ChannelQueue::new());
190 let queue_b_to_a = Arc::new(ChannelQueue::new());
191
192 let channel_a = LocalChannel::new(
193 Arc::clone(&queue_a_to_b),
194 Arc::clone(&queue_b_to_a),
195 );
196 let channel_b = LocalChannel::new(
197 Arc::clone(&queue_b_to_a),
198 Arc::clone(&queue_a_to_b),
199 );
200
201 Self { channel_a, channel_b }
202 }
203}
204
205impl Default for LocalChannelPair {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211pub struct TcpChannel {
223 stream: Arc<Mutex<TcpStream>>,
227 open: AtomicBool,
229}
230
231impl TcpChannel {
232 pub fn new(stream: TcpStream) -> Self {
237 Self {
238 stream: Arc::new(Mutex::new(stream)),
239 open: AtomicBool::new(true),
240 }
241 }
242
243 pub fn connect(addr: SocketAddr) -> Result<Self> {
248 let stream = TcpStream::connect_timeout(&addr, Duration::from_secs(30))
249 .map_err(|e| RepError::NetworkError(e.to_string()))?;
250 Ok(Self::new(stream))
251 }
252
253 pub fn connect_host(host: &str, port: u16) -> Result<Self> {
262 let addrs: Vec<SocketAddr> = (host, port)
263 .to_socket_addrs()
264 .map_err(|e| {
265 RepError::NetworkError(format!(
266 "DNS resolution failed for {host}:{port}: {e}"
267 ))
268 })?
269 .collect();
270
271 if addrs.is_empty() {
272 return Err(RepError::NetworkError(format!(
273 "no addresses resolved for {host}:{port}"
274 )));
275 }
276
277 let mut sorted = addrs;
279 sorted.sort_by_key(|a| if a.is_ipv6() { 0u8 } else { 1u8 });
280
281 let mut last_err = None;
282 for addr in &sorted {
283 match TcpStream::connect_timeout(addr, Duration::from_secs(30)) {
284 Ok(stream) => return Ok(Self::new(stream)),
285 Err(e) => last_err = Some(e),
286 }
287 }
288
289 Err(RepError::NetworkError(format!(
290 "could not connect to {host}:{port}: {}",
291 last_err.unwrap()
292 )))
293 }
294
295 pub fn bind_dual_stack(port: u16) -> Result<TcpChannelListener> {
302 if let Ok(listener) = TcpListener::bind(format!("[::]:{}", port)) {
304 return Ok(TcpChannelListener { listener });
305 }
306 let addr: SocketAddr =
308 format!("0.0.0.0:{port}").parse().map_err(|e| {
309 RepError::NetworkError(format!("invalid bind addr: {e}"))
310 })?;
311 TcpChannelListener::bind(addr)
312 }
313}
314
315impl Channel for TcpChannel {
316 fn send(&self, data: &[u8]) -> Result<()> {
321 if !self.is_open() {
322 return Err(RepError::ChannelClosed("TcpChannel is closed".into()));
323 }
324 let len = data.len() as u32;
325 let mut stream = self.stream.lock();
326 stream.set_write_timeout(Some(Duration::from_secs(30))).ok();
328 stream
329 .write_all(&len.to_le_bytes())
330 .map_err(|e| RepError::NetworkError(e.to_string()))?;
331 stream
332 .write_all(data)
333 .map_err(|e| RepError::NetworkError(e.to_string()))?;
334 stream.flush().map_err(|e| RepError::NetworkError(e.to_string()))?;
335 Ok(())
336 }
337
338 fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
344 if !self.is_open() {
345 return Err(RepError::ChannelClosed("TcpChannel is closed".into()));
346 }
347
348 let mut stream = self.stream.lock();
349
350 stream
352 .set_read_timeout(Some(timeout))
353 .map_err(|e| RepError::NetworkError(e.to_string()))?;
354
355 let mut len_buf = [0u8; 4];
357 match stream.read_exact(&mut len_buf) {
358 Ok(()) => {}
359 Err(e) => {
360 if e.kind() == std::io::ErrorKind::WouldBlock
362 || e.kind() == std::io::ErrorKind::TimedOut
363 {
364 return Ok(None);
365 }
366 if e.kind() == std::io::ErrorKind::UnexpectedEof {
368 return Err(RepError::ChannelClosed(
369 "connection closed by peer".into(),
370 ));
371 }
372 return Err(RepError::NetworkError(e.to_string()));
373 }
374 }
375
376 let payload_len = u32::from_le_bytes(len_buf) as usize;
377 if payload_len > MAX_FRAME_PAYLOAD {
378 return Err(RepError::ProtocolError(format!(
379 "frame payload too large: {} > {}",
380 payload_len, MAX_FRAME_PAYLOAD
381 )));
382 }
383
384 let payload_timeout = timeout.max(Duration::from_secs(30));
390 stream.set_read_timeout(Some(payload_timeout)).ok();
391
392 let mut payload = vec![0u8; payload_len];
393 stream
394 .read_exact(&mut payload)
395 .map_err(|e| RepError::NetworkError(e.to_string()))?;
396
397 Ok(Some(payload))
398 }
399
400 fn close(&self) -> Result<()> {
402 self.open.store(false, Ordering::SeqCst);
403 let stream = self.stream.lock();
404 stream
405 .shutdown(std::net::Shutdown::Both)
406 .map_err(|e| RepError::NetworkError(e.to_string()))
407 }
408
409 fn is_open(&self) -> bool {
410 self.open.load(Ordering::SeqCst)
411 }
412}
413
414pub struct TcpChannelListener {
424 listener: TcpListener,
425}
426
427impl TcpChannelListener {
428 pub fn bind(addr: SocketAddr) -> Result<Self> {
430 let listener = TcpListener::bind(addr)
431 .map_err(|e| RepError::NetworkError(e.to_string()))?;
432 Ok(Self { listener })
433 }
434
435 pub fn local_addr(&self) -> Result<SocketAddr> {
452 self.listener
453 .local_addr()
454 .map_err(|e| RepError::NetworkError(e.to_string()))
455 }
456
457 pub fn accept(&self) -> Result<TcpChannel> {
461 let (stream, _peer) = self
462 .listener
463 .accept()
464 .map_err(|e| RepError::NetworkError(e.to_string()))?;
465 Ok(TcpChannel::new(stream))
466 }
467
468 pub fn set_accept_timeout(&self, timeout: Option<Duration>) -> Result<()> {
473 #[cfg(unix)]
474 {
475 use std::os::fd::AsRawFd;
476 let fd = self.listener.as_raw_fd();
477 let tv = match timeout {
478 Some(d) => libc::timeval {
479 tv_sec: d.as_secs() as libc::time_t,
480 tv_usec: d.subsec_micros() as libc::suseconds_t,
481 },
482 None => libc::timeval { tv_sec: 0, tv_usec: 0 },
483 };
484 let rc = unsafe {
488 libc::setsockopt(
489 fd,
490 libc::SOL_SOCKET,
491 libc::SO_RCVTIMEO,
492 &tv as *const _ as *const libc::c_void,
493 std::mem::size_of::<libc::timeval>() as libc::socklen_t,
494 )
495 };
496 if rc != 0 {
497 return Err(RepError::NetworkError(
498 std::io::Error::last_os_error().to_string(),
499 ));
500 }
501 }
502 #[cfg(not(unix))]
503 {
504 let _ = timeout;
505 }
506 Ok(())
507 }
508}
509
510#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
515use crate::tls::TlsConfig;
516
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// Object-safe view of a blocking TLS stream over a `TcpStream`.
///
/// Lets `TlsTcpChannel` hold rustls client/server streams and native-tls
/// streams behind one `Box<dyn TlsStreamOps>`.
trait TlsStreamOps: Send + 'static {
    /// Reads exactly `buf.len()` plaintext bytes.
    fn read_exact_buf(&mut self, buf: &mut [u8]) -> std::io::Result<()>;
    /// Writes all of `buf` as plaintext.
    fn write_all_buf(&mut self, buf: &[u8]) -> std::io::Result<()>;
    /// Flushes buffered plaintext/TLS data to the transport.
    fn flush_buf(&mut self) -> std::io::Result<()>;
    /// Sets the read timeout of the underlying TCP socket.
    fn set_read_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()>;
    /// Sets the write timeout of the underlying TCP socket.
    fn set_write_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()>;
    /// Shuts down both directions of the underlying TCP socket.
    fn shutdown_inner(&self) -> std::io::Result<()>;
}
536
#[cfg(feature = "tls-rustls")]
/// Server-side rustls stream: plaintext I/O goes through rustls, while
/// timeouts and shutdown are applied directly to the wrapped socket.
impl TlsStreamOps for rustls::StreamOwned<rustls::ServerConnection, TcpStream> {
    fn read_exact_buf(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        <Self as IoRead>::read_exact(self, buf)
    }

    fn write_all_buf(&mut self, buf: &[u8]) -> std::io::Result<()> {
        <Self as IoWrite>::write_all(self, buf)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        <Self as IoWrite>::flush(self)
    }

    fn set_read_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_read_timeout(&self.sock, dur)
    }

    fn set_write_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_write_timeout(&self.sock, dur)
    }

    fn shutdown_inner(&self) -> std::io::Result<()> {
        TcpStream::shutdown(&self.sock, std::net::Shutdown::Both)
    }
}
566
#[cfg(feature = "tls-rustls")]
/// Client-side rustls stream: plaintext I/O goes through rustls, while
/// timeouts and shutdown are applied directly to the wrapped socket.
impl TlsStreamOps for rustls::StreamOwned<rustls::ClientConnection, TcpStream> {
    fn read_exact_buf(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        <Self as IoRead>::read_exact(self, buf)
    }

    fn write_all_buf(&mut self, buf: &[u8]) -> std::io::Result<()> {
        <Self as IoWrite>::write_all(self, buf)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        <Self as IoWrite>::flush(self)
    }

    fn set_read_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_read_timeout(&self.sock, dur)
    }

    fn set_write_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_write_timeout(&self.sock, dur)
    }

    fn shutdown_inner(&self) -> std::io::Result<()> {
        TcpStream::shutdown(&self.sock, std::net::Shutdown::Both)
    }
}
594
#[cfg(feature = "tls-native")]
/// native-tls stream: plaintext I/O goes through the TLS stream, while
/// timeouts and shutdown act on the inner socket obtained via `get_ref`.
impl TlsStreamOps for native_tls::TlsStream<TcpStream> {
    fn read_exact_buf(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        <Self as IoRead>::read_exact(self, buf)
    }

    fn write_all_buf(&mut self, buf: &[u8]) -> std::io::Result<()> {
        <Self as IoWrite>::write_all(self, buf)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        <Self as IoWrite>::flush(self)
    }

    fn set_read_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_read_timeout(self.get_ref(), dur)
    }

    fn set_write_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_write_timeout(self.get_ref(), dur)
    }

    fn shutdown_inner(&self) -> std::io::Result<()> {
        TcpStream::shutdown(self.get_ref(), std::net::Shutdown::Both)
    }
}
624
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// A [`Channel`] over TLS on top of TCP, using the same length-prefixed
/// framing as [`TcpChannel`].
///
/// The TLS stream sits behind a `std::sync::Mutex`; a poisoned lock is
/// reported as `RepError::NetworkError`.
pub struct TlsTcpChannel {
    stream: Arc<std::sync::Mutex<Box<dyn TlsStreamOps>>>,
    open: AtomicBool,
}

#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl TlsTcpChannel {
    /// Timeout applied to the outgoing TCP connection attempt.
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Upper bound on each socket read/write during a blocking native-tls
    /// handshake.
    #[cfg(feature = "tls-native")]
    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);

    /// Wraps an established TLS stream; the channel starts open.
    fn wrap(stream: Box<dyn TlsStreamOps>) -> Self {
        Self {
            stream: Arc::new(std::sync::Mutex::new(stream)),
            open: AtomicBool::new(true),
        }
    }

    /// Connects to `addr` and sets up TLS according to `tls`.
    ///
    /// Uses rustls when the `tls-rustls` feature is enabled, otherwise
    /// native-tls.
    ///
    /// # Errors
    /// `RepError::NetworkError` if the TCP connection or the TLS setup fails.
    pub fn connect_with_tls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        #[cfg(feature = "tls-rustls")]
        {
            return Self::connect_rustls(addr, tls);
        }
        #[cfg(all(feature = "tls-native", not(feature = "tls-rustls")))]
        {
            return Self::connect_native(addr, tls);
        }
        #[allow(unreachable_code)]
        Err(RepError::NetworkError("no TLS feature enabled".into()))
    }

    /// rustls client setup.
    ///
    /// The handshake itself is driven by rustls' `StreamOwned` during the
    /// first read/write, where the per-call socket timeouts of `send` and
    /// `receive` apply.
    #[cfg(feature = "tls-rustls")]
    fn connect_rustls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        use rustls::pki_types::ServerName;
        let cfg = tls.to_rustls_client_config()?;
        let server_name = ServerName::try_from(tls.server_name.clone())
            .map_err(|e| {
                RepError::NetworkError(format!("invalid server name: {e}"))
            })?;
        let tcp = TcpStream::connect_timeout(&addr, Self::CONNECT_TIMEOUT)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        let conn =
            rustls::ClientConnection::new(cfg, server_name).map_err(|e| {
                RepError::NetworkError(format!("TLS client init: {e}"))
            })?;
        let stream = rustls::StreamOwned::new(conn, tcp);
        Ok(Self::wrap(Box::new(stream)))
    }

    /// native-tls client setup; the handshake completes inside `connect`.
    #[cfg(feature = "tls-native")]
    fn connect_native(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        let connector = tls.to_native_connector()?;
        let tcp = TcpStream::connect_timeout(&addr, Self::CONNECT_TIMEOUT)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        // `connect` blocks until the handshake finishes; bound the socket
        // I/O so an unresponsive server cannot hang this call forever.
        // `send`/`receive` later replace these timeouts on every call.
        tcp.set_read_timeout(Some(Self::HANDSHAKE_TIMEOUT))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        tcp.set_write_timeout(Some(Self::HANDSHAKE_TIMEOUT))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        let stream = connector.connect(&tls.server_name, tcp).map_err(|e| {
            RepError::NetworkError(format!("TLS handshake: {e}"))
        })?;
        Ok(Self::wrap(Box::new(stream)))
    }
}
711
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl Channel for TlsTcpChannel {
    /// Sends `data` as one length-prefixed frame (4-byte little-endian `u32`
    /// length, then payload) with a 30-second write timeout.
    ///
    /// # Errors
    /// - `RepError::ChannelClosed` if `close` has been called.
    /// - `RepError::ProtocolError` if `data` is longer than `u32::MAX` bytes.
    /// - `RepError::NetworkError` on lock poisoning or any I/O failure.
    fn send(&self, data: &[u8]) -> Result<()> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed(
                "TlsTcpChannel is closed".into(),
            ));
        }
        // `as u32` would silently truncate the prefix and corrupt framing.
        if data.len() > u32::MAX as usize {
            return Err(RepError::ProtocolError(format!(
                "frame payload too large to encode: {} bytes",
                data.len()
            )));
        }
        let len = data.len() as u32;

        // One buffer means one plaintext write: the prefix and payload share
        // TLS records and socket writes instead of a tiny record being
        // flushed on its own ahead of the payload.
        let mut frame = Vec::with_capacity(4 + data.len());
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(data);

        let mut s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        s.set_write_timeout_inner(Some(Duration::from_secs(30)))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        s.write_all_buf(&frame)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        s.flush_buf().map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(())
    }

    /// Receives the next frame, waiting up to `timeout` for its length prefix.
    ///
    /// Returns `Ok(None)` if the prefix read times out. The payload is read
    /// with a timeout of at least 30 seconds.
    ///
    /// NOTE(review): the prefix is read with `read_exact`, whose progress is
    /// unspecified on error; a timeout after part of the prefix was consumed
    /// would desynchronise the stream. `TlsStreamOps` exposes no partial-read
    /// method to guard against that here.
    ///
    /// # Errors
    /// - `RepError::ChannelClosed` if this end is closed or the peer closed
    ///   the connection before a prefix arrived.
    /// - `RepError::ProtocolError` for a length above [`MAX_FRAME_PAYLOAD`].
    /// - `RepError::NetworkError` on lock poisoning or other I/O failures.
    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed(
                "TlsTcpChannel is closed".into(),
            ));
        }
        let mut s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        // A zero socket timeout is rejected by std; treat zero as a minimal
        // poll so `receive(Duration::ZERO)` yields `Ok(None)`, not an error.
        let header_timeout = timeout.max(Duration::from_millis(1));
        s.set_read_timeout_inner(Some(header_timeout))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        let mut len_buf = [0u8; 4];
        match s.read_exact_buf(&mut len_buf) {
            Ok(()) => {}
            Err(e) => {
                if e.kind() == std::io::ErrorKind::WouldBlock
                    || e.kind() == std::io::ErrorKind::TimedOut
                {
                    return Ok(None);
                }
                if e.kind() == std::io::ErrorKind::UnexpectedEof {
                    return Err(RepError::ChannelClosed(
                        "connection closed by peer".into(),
                    ));
                }
                return Err(RepError::NetworkError(e.to_string()));
            }
        }
        let payload_len = u32::from_le_bytes(len_buf) as usize;
        if payload_len > MAX_FRAME_PAYLOAD {
            return Err(RepError::ProtocolError(format!(
                "frame payload too large: {} > {}",
                payload_len, MAX_FRAME_PAYLOAD
            )));
        }
        // Once the prefix is in, give the payload at least 30 s to arrive.
        let payload_timeout = timeout.max(Duration::from_secs(30));
        s.set_read_timeout_inner(Some(payload_timeout))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        let mut payload = vec![0u8; payload_len];
        s.read_exact_buf(&mut payload)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(Some(payload))
    }

    /// Marks the channel closed and shuts down the underlying TCP socket.
    ///
    /// Only the first call touches the socket; later calls return `Ok(())`.
    ///
    /// # Errors
    /// `RepError::NetworkError` on lock poisoning or a failed first shutdown.
    fn close(&self) -> Result<()> {
        if !self.open.swap(false, Ordering::SeqCst) {
            return Ok(());
        }
        let s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        s.shutdown_inner().map_err(|e| RepError::NetworkError(e.to_string()))
    }

    /// Reports whether `close` has not yet been called.
    fn is_open(&self) -> bool {
        self.open.load(Ordering::SeqCst)
    }
}
790
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// Server-side TLS backend chosen when a `TlsTcpChannelListener` is bound.
///
/// Each variant exists only when its Cargo feature is enabled.
enum TlsAcceptorImpl {
    /// rustls: a shared server config from which a new `ServerConnection`
    /// is created for every accepted socket.
    #[cfg(feature = "tls-rustls")]
    Rustls(std::sync::Arc<rustls::ServerConfig>),
    /// native-tls: an acceptor that runs the TLS handshake on every accepted
    /// socket.
    #[cfg(feature = "tls-native")]
    Native(native_tls::TlsAcceptor),
}
800
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// A TCP listener that wraps every accepted connection in TLS and hands it
/// out as a [`TlsTcpChannel`].
pub struct TlsTcpChannelListener {
    listener: TcpListener,
    acceptor: TlsAcceptorImpl,
}

#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl TlsTcpChannelListener {
    /// Upper bound on each socket read/write during a blocking native-tls
    /// server handshake.
    #[cfg(feature = "tls-native")]
    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);

    /// Binds to `addr` and prepares the server-side TLS backend from `tls`
    /// (rustls when `tls-rustls` is enabled, otherwise native-tls).
    ///
    /// # Errors
    /// `RepError::NetworkError` if the bind fails; errors from building the
    /// TLS configuration are propagated.
    pub fn bind_with_tls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        let listener = TcpListener::bind(addr)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        #[cfg(feature = "tls-rustls")]
        let acceptor = {
            let cfg = tls.to_rustls_server_config()?;
            TlsAcceptorImpl::Rustls(cfg)
        };
        #[cfg(all(feature = "tls-native", not(feature = "tls-rustls")))]
        let acceptor = {
            let a = tls.to_native_acceptor()?;
            TlsAcceptorImpl::Native(a)
        };
        Ok(Self { listener, acceptor })
    }

    /// Like [`bind_with_tls`](Self::bind_with_tls), but builds the rustls
    /// server config with `allowlist` (via
    /// `TlsConfig::to_rustls_server_config_with_allowlist`).
    ///
    /// # Errors
    /// `RepError::NetworkError` if the bind fails; configuration errors are
    /// propagated.
    #[cfg(feature = "tls-rustls")]
    pub fn bind_with_tls_and_allowlist(
        addr: SocketAddr,
        tls: &TlsConfig,
        allowlist: crate::auth::PeerAllowlist,
    ) -> Result<Self> {
        let listener = TcpListener::bind(addr)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        let cfg = tls.to_rustls_server_config_with_allowlist(allowlist)?;
        Ok(Self { listener, acceptor: TlsAcceptorImpl::Rustls(cfg) })
    }

    /// Returns the bound address (useful after binding to port 0).
    ///
    /// # Errors
    /// `RepError::NetworkError` if the OS query fails.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.listener
            .local_addr()
            .map_err(|e| RepError::NetworkError(e.to_string()))
    }

    /// Accepts one TCP connection and wraps it in TLS.
    ///
    /// With rustls the handshake is deferred to the channel's first
    /// read/write; with native-tls it completes here, bounded by a
    /// per-operation socket timeout.
    ///
    /// # Errors
    /// `RepError::NetworkError` if `accept`, TLS initialisation, or the
    /// native-tls handshake fails.
    pub fn accept(&self) -> Result<TlsTcpChannel> {
        let (tcp, _peer) = self
            .listener
            .accept()
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        match &self.acceptor {
            #[cfg(feature = "tls-rustls")]
            TlsAcceptorImpl::Rustls(cfg) => {
                let conn = rustls::ServerConnection::new(Arc::clone(cfg))
                    .map_err(|e| {
                        RepError::NetworkError(format!("TLS server init: {e}"))
                    })?;
                let stream = rustls::StreamOwned::new(conn, tcp);
                Ok(TlsTcpChannel::wrap(Box::new(stream)))
            }
            #[cfg(feature = "tls-native")]
            TlsAcceptorImpl::Native(acceptor) => {
                // The blocking handshake would otherwise wait forever on a
                // client that connects but never speaks, stalling every
                // caller that loops on `accept`.
                tcp.set_read_timeout(Some(Self::HANDSHAKE_TIMEOUT))
                    .map_err(|e| RepError::NetworkError(e.to_string()))?;
                tcp.set_write_timeout(Some(Self::HANDSHAKE_TIMEOUT))
                    .map_err(|e| RepError::NetworkError(e.to_string()))?;
                let stream = acceptor.accept(tcp).map_err(|e| {
                    RepError::NetworkError(format!("TLS handshake: {e}"))
                })?;
                Ok(TlsTcpChannel::wrap(Box::new(stream)))
            }
        }
    }
}
884
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_send_receive_basic() {
        let pair = LocalChannelPair::new();
        let msg = b"hello world";
        pair.channel_a.send(msg).unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(msg.to_vec()));
    }

    #[test]
    fn test_bidirectional() {
        let pair = LocalChannelPair::new();

        pair.channel_a.send(b"from a").unwrap();
        pair.channel_b.send(b"from b").unwrap();

        let recv_b = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(recv_b, Some(b"from a".to_vec()));

        let recv_a = pair.channel_a.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(recv_a, Some(b"from b".to_vec()));
    }

    #[test]
    fn test_multiple_messages_fifo() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"first").unwrap();
        pair.channel_a.send(b"second").unwrap();
        pair.channel_a.send(b"third").unwrap();

        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"first".to_vec())
        );
        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"second".to_vec())
        );
        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"third".to_vec())
        );
    }

    #[test]
    fn test_receive_timeout_empty_queue() {
        let pair = LocalChannelPair::new();
        let result = pair.channel_b.receive(Duration::from_millis(50)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_send_after_close_fails() {
        let pair = LocalChannelPair::new();
        pair.channel_a.close().unwrap();
        let result = pair.channel_a.send(b"should fail");
        assert!(result.is_err());
    }

    #[test]
    fn test_receive_after_close_fails() {
        let pair = LocalChannelPair::new();
        pair.channel_b.close().unwrap();
        let result = pair.channel_b.receive(Duration::from_millis(10));
        assert!(result.is_err());
    }

    /// Messages queued before the peer closed are still delivered; the
    /// closure is reported only after they have been drained.
    #[test]
    fn test_peer_close_drains_then_errors() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"last words").unwrap();
        pair.channel_a.close().unwrap();

        assert_eq!(
            pair.channel_b.receive(Duration::from_millis(10)).unwrap(),
            Some(b"last words".to_vec())
        );
        assert!(matches!(
            pair.channel_b.receive(Duration::from_millis(10)),
            Err(RepError::ChannelClosed(_))
        ));
    }

    #[test]
    fn test_is_open() {
        let pair = LocalChannelPair::new();
        assert!(pair.channel_a.is_open());
        assert!(pair.channel_b.is_open());

        pair.channel_a.close().unwrap();
        assert!(!pair.channel_a.is_open());
        // Closing one end does not change the other end's local state.
        assert!(pair.channel_b.is_open());
    }

    #[test]
    fn test_empty_message() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"").unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(vec![]));
    }

    #[test]
    fn test_large_message() {
        let pair = LocalChannelPair::new();
        let large = vec![0xABu8; 1024 * 1024];
        pair.channel_a.send(&large).unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(large));
    }

    #[test]
    fn test_concurrent_send_receive() {
        let pair = LocalChannelPair::new();
        let a = &pair.channel_a;
        let b = &pair.channel_b;

        std::thread::scope(|s| {
            let handle = s.spawn(move || {
                let msg = b.receive(Duration::from_secs(5)).unwrap();
                assert_eq!(msg, Some(b"concurrent".to_vec()));
                b.send(b"ack").unwrap();
            });

            a.send(b"concurrent").unwrap();
            let ack = a.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(ack, Some(b"ack".to_vec()));
            handle.join().unwrap();
        });
    }

    #[test]
    fn test_default_trait() {
        let pair = LocalChannelPair::default();
        assert!(pair.channel_a.is_open());
        assert!(pair.channel_b.is_open());
    }

    #[test]
    fn test_tcp_channel_send_receive() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            let msg = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(msg, Some(b"hello tcp".to_vec()));
            ch.send(b"world").unwrap();
        });

        let client = TcpChannel::connect(addr).unwrap();
        client.send(b"hello tcp").unwrap();
        let reply = client.receive(Duration::from_secs(5)).unwrap();
        assert_eq!(reply, Some(b"world".to_vec()));

        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_multiple_messages() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            for i in 0u8..5 {
                let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                assert_eq!(msg, vec![i]);
            }
        });

        let client = TcpChannel::connect(addr).unwrap();
        for i in 0u8..5 {
            client.send(&[i]).unwrap();
        }
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_receive_timeout() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            // Keep the connection open but never send anything.
            let (_stream, _) = listener.accept().unwrap();
            std::thread::sleep(Duration::from_secs(2));
        });

        let client = TcpChannel::connect(addr).unwrap();
        let result = client.receive(Duration::from_millis(100)).unwrap();
        assert_eq!(result, None, "expected timeout → None");

        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_is_open_and_close() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            let (_stream, _) = listener.accept().unwrap();
            std::thread::sleep(Duration::from_millis(200));
        });

        let client = TcpChannel::connect(addr).unwrap();
        assert!(client.is_open());
        client.close().unwrap();
        assert!(!client.is_open());

        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_large_payload() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let payload: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();
        let expected = payload.clone();

        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
            assert_eq!(msg, expected);
        });

        let client = TcpChannel::connect(addr).unwrap();
        client.send(&payload).unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_listener_bind_and_accept() {
        let listener =
            TcpChannelListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            let msg = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(msg, Some(b"ping".to_vec()));
        });

        let client = TcpChannel::connect(addr).unwrap();
        client.send(b"ping").unwrap();
        handle.join().unwrap();
    }

    /// A length prefix above `MAX_FRAME_PAYLOAD` must be rejected before any
    /// payload is read or allocated.
    #[test]
    fn test_tcp_channel_rejects_oversize_frame() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let handle = std::thread::spawn(move || {
            // Write only a raw, oversized length prefix.
            let (mut stream, _) = listener.accept().unwrap();
            let oversized = (MAX_FRAME_PAYLOAD as u32).saturating_add(1);
            stream.write_all(&oversized.to_le_bytes()).unwrap();
            std::thread::sleep(Duration::from_millis(200));
        });

        let client = TcpChannel::connect(addr).unwrap();
        let err = client
            .receive(Duration::from_secs(5))
            .expect_err("oversize frame must be rejected");
        match err {
            RepError::ProtocolError(msg) => {
                assert!(
                    msg.contains("frame payload too large"),
                    "unexpected protocol-error message: {}",
                    msg
                );
            }
            other => panic!("expected ProtocolError, got {:?}", other),
        }

        handle.join().unwrap();
    }

    #[cfg(feature = "tls-rustls")]
    mod tls_tests {
        use super::*;
        use crate::tls::TlsConfig;

        #[test]
        fn test_tls_tcp_send_receive() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                let msg = ch.receive(Duration::from_secs(5)).unwrap();
                assert_eq!(msg, Some(b"hello tls".to_vec()));
                ch.send(b"world tls").unwrap();
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            client.send(b"hello tls").unwrap();
            let reply = client.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(reply, Some(b"world tls".to_vec()));

            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_multiple_messages() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                for i in 0u8..4 {
                    let msg =
                        ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                    assert_eq!(msg, vec![i]);
                }
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            for i in 0u8..4 {
                client.send(&[i]).unwrap();
            }
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_large_payload() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();
            let payload: Vec<u8> =
                (0..65536).map(|i| (i % 256) as u8).collect();
            let expected = payload.clone();

            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                assert_eq!(msg, expected);
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            client.send(&payload).unwrap();
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_receive_timeout() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                // Accept, but never send application data.
                let _ch = listener.accept().unwrap();
                std::thread::sleep(Duration::from_secs(2));
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            let result = client.receive(Duration::from_millis(500)).unwrap();
            assert_eq!(result, None, "expected timeout → None");

            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_close() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                let _ch = listener.accept().unwrap();
                std::thread::sleep(Duration::from_millis(200));
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            assert!(client.is_open());
            client.close().unwrap();
            assert!(!client.is_open());

            handle.join().unwrap();
        }

        /// An oversize frame arriving over TLS must be rejected by the
        /// receiver's length check.
        #[test]
        fn test_tls_tcp_rejects_oversize_frame() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();

            let handle = std::thread::spawn(move || {
                // The send is expected to fail once the client hangs up.
                let ch = listener.accept().unwrap();
                let oversized = vec![0u8; MAX_FRAME_PAYLOAD + 1];
                let _ = ch.send(&oversized);
            });

            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            let result = client.receive(Duration::from_secs(10));
            // Close first so the server's blocked send errors out and its
            // thread can be joined.
            let _ = client.close();
            let err = result.expect_err("oversize TLS frame must be rejected");
            match err {
                RepError::ProtocolError(msg) => {
                    assert!(
                        msg.contains("frame payload too large"),
                        "unexpected protocol-error message: {}",
                        msg
                    );
                }
                other => panic!("expected ProtocolError, got {:?}", other),
            }
            let _ = handle.join();
        }
    }
}