// ftp_rs/connection.rs

use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
#[cfg(feature = "ftps")]
use tokio_rustls::client::TlsStream;
10#[pin_project(project = ConnectionProj)]
11pub enum Connection {
12    Tcp(#[pin] TcpStream),
13    #[cfg(feature = "ftps")]
14    Ssl(#[pin] TlsStream<TcpStream>),
15}
17impl Connection {
18    /// Unwrap the stream into TcpStream. This method is only used in secure connection.
19    pub fn into_tcp_stream(self) -> TcpStream {
20        match self {
21            Connection::Tcp(stream) => stream,
22            #[cfg(feature = "ftps")]
23            Connection::Ssl(stream) => stream.into_inner().0,
24        }
25    }
26
27    /// Test if the stream is secured
28    pub fn is_ssl(&self) -> bool {
29        match self {
30            #[cfg(feature = "ftps")]
31            Connection::Ssl(_) => true,
32            _ => false,
33        }
34    }
35
36    /// Returns a reference to the underlying TcpStream.
37    pub fn get_ref(&self) -> &TcpStream {
38        match self {
39            Connection::Tcp(ref stream) => stream,
40            #[cfg(feature = "ftps")]
41            Connection::Ssl(ref stream) => stream.get_ref().0,
42        }
43    }
44}
46impl AsyncRead for Connection {
47    fn poll_read(
48        self: Pin<&mut Self>,
49        cx: &mut Context<'_>,
50        buf: &mut ReadBuf<'_>,
51    ) -> Poll<io::Result<()>> {
52        match self.project() {
53            ConnectionProj::Tcp(stream) => stream.poll_read(cx, buf),
54            #[cfg(feature = "ftps")]
55            ConnectionProj::Ssl(stream) => stream.poll_read(cx, buf),
56        }
57    }
58}
60impl AsyncWrite for Connection {
61    fn poll_write(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        buf: &[u8],
65    ) -> Poll<io::Result<usize>> {
66        match self.project() {
67            ConnectionProj::Tcp(stream) => stream.poll_write(cx, buf),
68            #[cfg(feature = "ftps")]
69            ConnectionProj::Ssl(stream) => stream.poll_write(cx, buf),
70        }
71    }
72
73    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
74        match self.project() {
75            ConnectionProj::Tcp(stream) => stream.poll_flush(cx),
76            #[cfg(feature = "ftps")]
77            ConnectionProj::Ssl(stream) => stream.poll_flush(cx),
78        }
79    }
80
81    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
82        match self.project() {
83            ConnectionProj::Tcp(stream) => stream.poll_shutdown(cx),
84            #[cfg(feature = "ftps")]
85            ConnectionProj::Ssl(stream) => stream.poll_shutdown(cx),
86        }
87    }
88}