1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
use std::{future::Future, net::SocketAddr};
use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use tokio::{io, net::TcpStream};
use tokio_util::codec::{BytesCodec, FramedWrite};
/// Connects to `addr`, lets `f` upgrade the raw [`TcpStream`] (for example by
/// wrapping it in TLS), then writes every chunk produced by `receiver` to the
/// resulting stream until `receiver` is exhausted.
///
/// Each `Bytes` item is written as-is through [`BytesCodec`]; no extra framing
/// is added. When `receiver` ends, the sink is flushed and closed. The read side
/// of the connection is never polled, so anything the peer sends is ignored.
///
/// # Errors
///
/// Returns the first I/O error from connecting, from the upgrade future
/// returned by `f`, or from writing/closing the stream.
async fn handle_tcp<F, R, S, I>(
    addr: SocketAddr,
    f: F,
    receiver: &mut S,
) -> Result<(), std::io::Error>
where
    S: Stream<Item = Bytes>,
    S: Unpin,
    I: io::AsyncRead + io::AsyncWrite + Send + Unpin,
    F: FnOnce(TcpStream) -> R,
    R: Future<Output = Result<I, std::io::Error>> + Send,
{
    let tcp = TcpStream::connect(addr).await?;
    let stream = f(tcp).await?;
    // `FramedWrite` only requires `AsyncWrite`, so the stream is used directly.
    // Splitting it just to drop the unused read half would add the internal
    // lock that `tokio::io::split` uses to share one stream between two halves.
    let sink = FramedWrite::new(stream, BytesCodec::new());
    receiver.map(Ok).forward(sink).await
}
/// Plain, unencrypted TCP transport: bytes from the receiver are written to
/// the socket without any wrapping.
#[derive(Debug)]
pub struct TcpConnection;

impl TcpConnection {
    /// Opens a plain TCP connection to `addr` and writes every item yielded by
    /// `receiver` to it until the stream is exhausted.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error raised while connecting or writing.
    pub(super) async fn handle<S>(
        &self,
        addr: SocketAddr,
        receiver: &mut S,
    ) -> Result<(), std::io::Error>
    where
        S: Stream<Item = Bytes> + Unpin,
    {
        // No handshake is needed here: the freshly connected socket is handed
        // back unchanged as the stream to write to.
        handle_tcp(addr, |socket| async move { Ok(socket) }, receiver).await
    }
}
/// TLS-over-TCP transport: the TCP socket is upgraded to a rustls client
/// stream before any bytes from the receiver are written.
#[cfg(feature = "rustls-tls")]
pub struct TlsConnection {
    /// Server name handed to `TlsConnector::connect` for the handshake.
    pub(crate) server_name: tokio_rustls::rustls::ServerName,
    /// Shared rustls client configuration; a connector is built from it per call.
    pub(crate) client_config: std::sync::Arc<tokio_rustls::rustls::ClientConfig>,
}

#[cfg(feature = "rustls-tls")]
impl TlsConnection {
    /// Connects to `addr` over TCP, performs a TLS handshake using
    /// `server_name` and `client_config`, and then writes every item yielded by
    /// `receiver` over the encrypted stream until the stream is exhausted.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from connecting, the handshake, or writing.
    pub(super) async fn handle<S>(
        &self,
        addr: SocketAddr,
        receiver: &mut S,
    ) -> Result<(), std::io::Error>
    where
        S: Stream<Item = Bytes> + Unpin,
    {
        let connector =
            tokio_rustls::TlsConnector::from(std::sync::Arc::clone(&self.client_config));
        let server_name = self.server_name.clone();
        // The closure is only invoked once the TCP connection is up, so the
        // handshake always runs on an already-connected socket.
        let upgrade = move |socket| connector.connect(server_name, socket);
        handle_tcp(addr, upgrade, receiver).await
    }
}

#[cfg(feature = "rustls-tls")]
impl std::fmt::Debug for TlsConnection {
    /// Shows `server_name` only; `client_config` is omitted and the output is
    /// marked non-exhaustive.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("TlsConnection");
        builder.field("server_name", &self.server_name);
        builder.finish_non_exhaustive()
    }
}