nntp_proxy/
stream.rs

//! Stream abstraction for supporting multiple connection types
//!
//! This module provides abstractions for handling different stream types (plain TCP and TLS)
//! in a unified way, so backend connections can be read from and written to through a single
//! type regardless of whether they are encrypted.
5
6use std::io;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10use tokio::net::TcpStream;
11use tokio_rustls::client::TlsStream;
12
13/// Trait for async streams that can be used for NNTP connections
14///
15/// This trait is automatically implemented for any type that implements
16/// AsyncRead + AsyncWrite + Unpin + Send, making it easy to support
17/// different connection types (TCP, TLS, etc.).
18pub trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
19
20// Blanket implementation for all types that meet the requirements
21impl<T> AsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
22
/// Unified stream type that can represent different connection types
///
/// This enum lets the proxy handle both plain TCP and TLS connections through a
/// single concrete type. Callers can match on the transport (or use the accessor
/// methods) instead of going through a `dyn` trait object. Note that the TLS
/// variant is still heap-allocated (it holds a `Box`); presumably this keeps the
/// enum itself small. NOTE(review): confirm with `size_of::<ConnectionStream>()`.
#[derive(Debug)]
pub enum ConnectionStream {
    /// Plain, unencrypted TCP connection
    Plain(TcpStream),
    /// TLS-encrypted (client-side) connection layered over TCP; boxed
    Tls(Box<TlsStream<TcpStream>>),
}
35
36impl ConnectionStream {
37    /// Create a new plain TCP connection stream
38    pub fn plain(stream: TcpStream) -> Self {
39        Self::Plain(stream)
40    }
41    
42    /// Create a new TLS-encrypted connection stream
43    pub fn tls(stream: TlsStream<TcpStream>) -> Self {
44        Self::Tls(Box::new(stream))
45    }
46    
47    /// Returns the connection type as a string for logging/debugging
48    pub fn connection_type(&self) -> &'static str {
49        match self {
50            Self::Plain(_) => "TCP",
51            Self::Tls(_) => "TLS",
52        }
53    }
54    
55    /// Returns true if this connection uses encryption (TLS/SSL)
56    pub fn is_encrypted(&self) -> bool {
57        matches!(self, Self::Tls(_))
58    }
59    
60    /// Returns true if this connection is unencrypted (plain TCP)
61    pub fn is_unencrypted(&self) -> bool {
62        matches!(self, Self::Plain(_))
63    }
64
65    /// Get a reference to the underlying TCP stream (if plain TCP)
66    ///
67    /// Returns None for TLS streams, as the TCP stream is wrapped.
68    /// Useful for socket optimization that requires direct TCP access.
69    pub fn as_tcp_stream(&self) -> Option<&TcpStream> {
70        match self {
71            Self::Plain(tcp) => Some(tcp),
72            Self::Tls(_) => None,
73        }
74    }
75
76    /// Get a mutable reference to the underlying TCP stream (if plain TCP)
77    pub fn as_tcp_stream_mut(&mut self) -> Option<&mut TcpStream> {
78        match self {
79            Self::Plain(tcp) => Some(tcp),
80            Self::Tls(_) => None,
81        }
82    }
83
84    /// Get a reference to the TLS stream (if TLS connection)
85    pub fn as_tls_stream(&self) -> Option<&TlsStream<TcpStream>> {
86        match self {
87            Self::Tls(tls) => Some(tls.as_ref()),
88            Self::Plain(_) => None,
89        }
90    }
91    
92    /// Get a mutable reference to the TLS stream (if TLS connection)
93    pub fn as_tls_stream_mut(&mut self) -> Option<&mut TlsStream<TcpStream>> {
94        match self {
95            Self::Tls(tls) => Some(tls.as_mut()),
96            Self::Plain(_) => None,
97        }
98    }
99    
100    /// Get the underlying TCP stream reference regardless of connection type
101    /// 
102    /// For plain TCP, returns the stream directly.
103    /// For TLS, returns the underlying TCP stream within the TLS wrapper.
104    pub fn underlying_tcp_stream(&self) -> &TcpStream {
105        match self {
106            Self::Plain(tcp) => tcp,
107            Self::Tls(tls) => tls.get_ref().0,
108        }
109    }
110}
111
112impl AsyncRead for ConnectionStream {
113    fn poll_read(
114        mut self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116        buf: &mut ReadBuf<'_>,
117    ) -> Poll<io::Result<()>> {
118        match &mut *self {
119            Self::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
120            Self::Tls(stream) => Pin::new(stream.as_mut()).poll_read(cx, buf),
121        }
122    }
123}
124
125impl AsyncWrite for ConnectionStream {
126    fn poll_write(
127        mut self: Pin<&mut Self>,
128        cx: &mut Context<'_>,
129        buf: &[u8],
130    ) -> Poll<io::Result<usize>> {
131        match &mut *self {
132            Self::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
133            Self::Tls(stream) => Pin::new(stream.as_mut()).poll_write(cx, buf),
134        }
135    }
136
137    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
138        match &mut *self {
139            Self::Plain(stream) => Pin::new(stream).poll_flush(cx),
140            Self::Tls(stream) => Pin::new(stream.as_mut()).poll_flush(cx),
141        }
142    }
143
144    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145        match &mut *self {
146            Self::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
147            Self::Tls(stream) => Pin::new(stream.as_mut()).poll_shutdown(cx),
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[tokio::test]
    async fn test_connection_stream_plain_tcp() {
        // Create a loopback listener and connect a client to it
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client_handle = tokio::spawn(async move { TcpStream::connect(addr).await.unwrap() });

        let (server_stream, _) = listener.accept().await.unwrap();
        let client_stream = client_handle.await.unwrap();

        // Wrap both ends in ConnectionStream
        let mut server_conn = ConnectionStream::plain(server_stream);
        let mut client_conn = ConnectionStream::plain(client_stream);

        // Client -> server through the AsyncWrite/AsyncRead impls
        client_conn.write_all(b"Hello").await.unwrap();
        let mut buf = [0u8; 5];
        server_conn.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"Hello");

        // Server -> client: both directions must be forwarded
        server_conn.write_all(b"200 OK\r\n").await.unwrap();
        let mut reply = [0u8; 8];
        client_conn.read_exact(&mut reply).await.unwrap();
        assert_eq!(&reply, b"200 OK\r\n");

        // Shutdown must reach the socket: the peer then observes EOF
        client_conn.shutdown().await.unwrap();
        let mut rest = Vec::new();
        let n = server_conn.read_to_end(&mut rest).await.unwrap();
        assert_eq!(n, 0);

        // Stream type checking
        assert!(client_conn.is_unencrypted());
        assert!(!client_conn.is_encrypted());
        assert_eq!(client_conn.connection_type(), "TCP");
        assert!(client_conn.as_tcp_stream().is_some());
        assert!(client_conn.as_tls_stream().is_none());
    }

    #[test]
    fn test_async_stream_trait() {
        // Both the raw socket and the wrapper must satisfy the blanket AsyncStream impl
        fn assert_async_stream<T: AsyncStream>() {}
        assert_async_stream::<TcpStream>();
        assert_async_stream::<ConnectionStream>();
    }

    #[tokio::test]
    async fn test_connection_stream_tcp_access() {
        // The listener is kept alive for the whole test so the queued
        // connection stays established (peer_addr is checked below).
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let tcp_stream = std::net::TcpStream::connect(addr).unwrap();
        tcp_stream.set_nonblocking(true).unwrap();
        let tokio_stream = TcpStream::from_std(tcp_stream).unwrap();

        let mut conn_stream = ConnectionStream::plain(tokio_stream);

        // Plain connections expose the TCP stream and no TLS stream
        assert!(conn_stream.as_tcp_stream().is_some());
        assert!(conn_stream.as_tcp_stream_mut().is_some());
        assert!(conn_stream.as_tls_stream().is_none());
        assert!(conn_stream.as_tls_stream_mut().is_none());

        assert!(conn_stream.is_unencrypted());
        assert!(!conn_stream.is_encrypted());
        assert_eq!(conn_stream.connection_type(), "TCP");

        // For plain TCP, underlying_tcp_stream is the very same socket
        let underlying = conn_stream.underlying_tcp_stream();
        assert!(std::ptr::eq(underlying, conn_stream.as_tcp_stream().unwrap()));
        assert_eq!(underlying.peer_addr().unwrap(), addr);
    }

    #[tokio::test]
    async fn test_connection_type_methods() {
        // Classification helpers must agree with each other for plain TCP
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let tcp = std::net::TcpStream::connect(addr).unwrap();
        tcp.set_nonblocking(true).unwrap();
        let stream = TcpStream::from_std(tcp).unwrap();

        let conn = ConnectionStream::plain(stream);

        assert!(conn.is_unencrypted(), "Plain TCP should be unencrypted");
        assert!(!conn.is_encrypted(), "Plain TCP should not be encrypted");
        assert_eq!(conn.connection_type(), "TCP");
    }
}