nntp_proxy/
stream.rs

//! Stream abstraction for supporting multiple connection types.
//!
//! This module provides abstractions for handling different stream types (plain TCP and TLS)
//! in a unified way, so the rest of the proxy can drive backend connections through one type
//! regardless of whether they are encrypted.
5
6use crate::tls::TlsStream;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::net::TcpStream;
12
13/// Trait for async streams that can be used for NNTP connections
14///
15/// This trait is automatically implemented for any type that implements
16/// AsyncRead + AsyncWrite + Unpin + Send, making it easy to support
17/// different connection types (TCP, TLS, etc.).
18pub trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
19
20// Blanket implementation for all types that meet the requirements
21impl<T> AsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
22
23/// Unified stream type that can represent different connection types
24///
25/// This enum allows the proxy to handle both plain TCP and TLS connections
26/// through a single type, avoiding the need for trait objects and their associated
27/// heap allocation overhead.
28#[derive(Debug)]
29pub enum ConnectionStream {
30    /// Plain TCP connection
31    Plain(TcpStream),
32    /// TLS-encrypted connection
33    Tls(Box<TlsStream<TcpStream>>),
34}
35
36impl ConnectionStream {
37    /// Create a new plain TCP connection stream
38    pub fn plain(stream: TcpStream) -> Self {
39        Self::Plain(stream)
40    }
41
42    /// Create a new TLS-encrypted connection stream
43    pub fn tls(stream: TlsStream<TcpStream>) -> Self {
44        Self::Tls(Box::new(stream))
45    }
46
47    /// Returns the connection type as a string for logging/debugging
48    #[must_use]
49    pub fn connection_type(&self) -> &'static str {
50        match self {
51            Self::Plain(_) => "TCP",
52            Self::Tls(_) => "TLS",
53        }
54    }
55
56    /// Returns true if this connection uses encryption (TLS/SSL)
57    #[inline]
58    #[must_use]
59    pub const fn is_encrypted(&self) -> bool {
60        matches!(self, Self::Tls(_))
61    }
62
63    /// Returns true if this connection is unencrypted (plain TCP)
64    #[inline]
65    #[must_use]
66    pub const fn is_unencrypted(&self) -> bool {
67        matches!(self, Self::Plain(_))
68    }
69
70    /// Get a reference to the underlying TCP stream (if plain TCP)
71    ///
72    /// Returns None for TLS streams, as the TCP stream is wrapped.
73    /// Useful for socket optimization that requires direct TCP access.
74    #[must_use]
75    pub fn as_tcp_stream(&self) -> Option<&TcpStream> {
76        match self {
77            Self::Plain(tcp) => Some(tcp),
78            Self::Tls(_) => None,
79        }
80    }
81
82    /// Get a mutable reference to the underlying TCP stream (if plain TCP)
83    pub fn as_tcp_stream_mut(&mut self) -> Option<&mut TcpStream> {
84        match self {
85            Self::Plain(tcp) => Some(tcp),
86            Self::Tls(_) => None,
87        }
88    }
89
90    /// Get a reference to the TLS stream (if TLS connection)
91    #[must_use]
92    pub fn as_tls_stream(&self) -> Option<&TlsStream<TcpStream>> {
93        match self {
94            Self::Tls(tls) => Some(tls.as_ref()),
95            Self::Plain(_) => None,
96        }
97    }
98
99    /// Get a mutable reference to the TLS stream (if TLS connection)
100    pub fn as_tls_stream_mut(&mut self) -> Option<&mut TlsStream<TcpStream>> {
101        match self {
102            Self::Tls(tls) => Some(tls.as_mut()),
103            Self::Plain(_) => None,
104        }
105    }
106
107    /// Get the underlying TCP stream reference regardless of connection type
108    ///
109    /// For plain TCP, returns the stream directly.
110    /// For TLS, returns the underlying TCP stream within the TLS wrapper.
111    #[must_use]
112    pub fn underlying_tcp_stream(&self) -> &TcpStream {
113        match self {
114            Self::Plain(tcp) => tcp,
115            Self::Tls(tls) => tls.get_ref().0,
116        }
117    }
118}
119
120impl AsyncRead for ConnectionStream {
121    fn poll_read(
122        mut self: Pin<&mut Self>,
123        cx: &mut Context<'_>,
124        buf: &mut ReadBuf<'_>,
125    ) -> Poll<io::Result<()>> {
126        match &mut *self {
127            Self::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
128            Self::Tls(stream) => Pin::new(stream.as_mut()).poll_read(cx, buf),
129        }
130    }
131}
132
133impl AsyncWrite for ConnectionStream {
134    fn poll_write(
135        mut self: Pin<&mut Self>,
136        cx: &mut Context<'_>,
137        buf: &[u8],
138    ) -> Poll<io::Result<usize>> {
139        match &mut *self {
140            Self::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
141            Self::Tls(stream) => Pin::new(stream.as_mut()).poll_write(cx, buf),
142        }
143    }
144
145    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
146        match &mut *self {
147            Self::Plain(stream) => Pin::new(stream).poll_flush(cx),
148            Self::Tls(stream) => Pin::new(stream.as_mut()).poll_flush(cx),
149        }
150    }
151
152    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
153        match &mut *self {
154            Self::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
155            Self::Tls(stream) => Pin::new(stream.as_mut()).poll_shutdown(cx),
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Opens a loopback TCP connection and returns its `(client, server)` ends.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (client, accepted) = tokio::join!(TcpStream::connect(addr), listener.accept());
        (client.unwrap(), accepted.unwrap().0)
    }

    /// Returns a plain `ConnectionStream` plus the peer socket.
    ///
    /// Bind the peer to `_server` (not `_`) so it stays open for the whole test.
    async fn plain_conn() -> (ConnectionStream, TcpStream) {
        let (client, server) = tcp_pair().await;
        (ConnectionStream::plain(client), server)
    }

    #[tokio::test]
    async fn test_connection_stream_plain_tcp() {
        let (client_stream, server_stream) = tcp_pair().await;
        let mut server_conn = ConnectionStream::plain(server_stream);
        let mut client_conn = ConnectionStream::plain(client_stream);

        // Data written through one wrapper must arrive intact through the other.
        client_conn.write_all(b"Hello").await.unwrap();
        let mut buf = [0u8; 5];
        server_conn.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"Hello");

        assert!(client_conn.is_unencrypted());
        assert!(!client_conn.is_encrypted());
        assert_eq!(client_conn.connection_type(), "TCP");
        assert!(client_conn.as_tcp_stream().is_some());
    }

    #[test]
    fn test_async_stream_trait() {
        // Compile-time check: both the raw socket and the wrapper satisfy `AsyncStream`.
        fn assert_async_stream<T: AsyncStream>() {}
        assert_async_stream::<TcpStream>();
        assert_async_stream::<ConnectionStream>();
    }

    #[tokio::test]
    async fn test_connection_stream_tcp_access() {
        let (mut conn, _server) = plain_conn().await;

        assert!(conn.as_tcp_stream().is_some());
        assert!(conn.as_tcp_stream_mut().is_some());
        assert!(conn.is_unencrypted());
        assert!(!conn.is_encrypted());
        assert_eq!(conn.connection_type(), "TCP");

        // For plain TCP the "underlying" stream is the very same socket.
        assert!(std::ptr::eq(
            conn.underlying_tcp_stream(),
            conn.as_tcp_stream().unwrap()
        ));
    }

    #[tokio::test]
    async fn test_connection_type_methods() {
        let (conn, _server) = plain_conn().await;

        assert!(conn.is_unencrypted(), "Plain TCP should be unencrypted");
        assert!(!conn.is_encrypted(), "Plain TCP should not be encrypted");
        assert_eq!(conn.connection_type(), "TCP");
    }

    #[tokio::test]
    async fn test_tls_connection_type_methods() {
        // A real TLS handshake needs a server and certificates, so this covers the
        // TLS-related accessors from the plain-TCP side: they must report "not TLS".
        let (conn, _server) = plain_conn().await;

        assert!(conn.is_unencrypted());
        assert!(!conn.is_encrypted());
        assert_eq!(conn.connection_type(), "TCP");

        assert!(conn.as_tcp_stream().is_some());
        assert!(conn.as_tls_stream().is_none());
    }

    #[tokio::test]
    async fn test_plain_tcp_stream_accessors() {
        let (mut conn, _server) = plain_conn().await;

        assert!(conn.as_tcp_stream().is_some());
        assert!(conn.as_tls_stream().is_none());

        assert!(conn.as_tcp_stream_mut().is_some());
        assert!(conn.as_tls_stream_mut().is_none());

        assert!(std::ptr::eq(
            conn.underlying_tcp_stream(),
            conn.as_tcp_stream().unwrap()
        ));
    }

    #[tokio::test]
    async fn test_connection_stream_enum_variants() {
        let (plain, _server) = plain_conn().await;

        // Exhaustive match: fails to compile if the variant set changes.
        let is_plain = match plain {
            ConnectionStream::Plain(_) => true,
            ConnectionStream::Tls(_) => false,
        };
        assert!(is_plain, "Should be Plain, not Tls");
    }

    #[tokio::test]
    async fn test_connection_type_string_values() {
        let (conn, _server) = plain_conn().await;

        let type_str = conn.connection_type();
        assert_eq!(type_str, "TCP");
        assert!(!type_str.is_empty());
        assert!(type_str.len() < 10); // Reasonable length for logging
    }

    #[tokio::test]
    async fn test_plain_connection_debug_format() {
        let (conn, _server) = plain_conn().await;

        let debug_str = format!("{conn:?}");
        assert!(
            debug_str.contains("Plain"),
            "Debug output should indicate Plain TCP"
        );
    }
}