1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
use std::{future::Future, net::SocketAddr};

use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use tokio::{io, net::TcpStream};
use tokio_util::codec::{BytesCodec, FramedWrite};

/// Connects to `addr` over TCP, lets `f` wrap the stream, and forwards `receiver` into it.
///
/// The function is generic over plain TCP and TLS via `f`: `f` receives the freshly
/// connected [`TcpStream`] and resolves to the I/O object that is actually written to
/// (the stream itself, or a TLS session layered on top of it).
///
/// Every [`Bytes`] item yielded by `receiver` is written through a [`BytesCodec`], i.e.
/// without any additional framing. The call completes once `receiver` is exhausted and
/// the sink has been flushed and closed by `forward`.
///
/// # Errors
///
/// Returns the [`std::io::Error`] produced by the TCP connect, by `f`, or by writing to
/// the wrapped stream.
async fn handle_tcp<F, R, S, I>(
    addr: SocketAddr,
    f: F,
    receiver: &mut S,
) -> Result<(), std::io::Error>
where
    S: Stream<Item = Bytes>,
    S: Unpin,
    I: io::AsyncRead + io::AsyncWrite + Send + Unpin,
    F: FnOnce(TcpStream) -> R,
    R: Future<Output = Result<I, std::io::Error>> + Send,
{
    let tcp = TcpStream::connect(addr).await?;
    let wrapped = f(tcp).await?;

    // Nothing is ever read from this connection, so frame the stream directly
    // instead of going through `io::split` only to throw the read half away.
    let sink = FramedWrite::new(wrapped, BytesCodec::new());
    receiver.map(Ok).forward(sink).await?;

    Ok(())
}

/// A TCP connection to Graylog.
///
/// Unit struct with no configuration: its `handle` method hands the connected
/// `TcpStream` to `handle_tcp` unchanged, so no TLS layer is added.
#[derive(Debug)]
pub struct TcpConnection;

impl TcpConnection {
    /// Connects to `addr` and forwards every item of `receiver` over the raw TCP stream.
    ///
    /// The stream wrapper passed to `handle_tcp` is the identity: the connected
    /// stream is used as-is.
    ///
    /// # Errors
    ///
    /// Propagates any `std::io::Error` returned by `handle_tcp`.
    pub(super) async fn handle<S>(
        &self,
        addr: SocketAddr,
        receiver: &mut S,
    ) -> Result<(), std::io::Error>
    where
        S: Stream<Item = Bytes> + Unpin,
    {
        handle_tcp(addr, |stream| async move { Ok(stream) }, receiver).await
    }
}

/// A TLS connection to Graylog.
///
/// Only compiled when the `rustls-tls` feature is enabled.
#[cfg(feature = "rustls-tls")]
pub struct TlsConnection {
    /// Server name passed to `TlsConnector::connect` together with the TCP stream.
    pub(crate) server_name: tokio_rustls::rustls::ServerName,
    /// Shared rustls client configuration from which a `TlsConnector` is built
    /// for each call to `handle`.
    pub(crate) client_config: std::sync::Arc<tokio_rustls::rustls::ClientConfig>,
}

#[cfg(feature = "rustls-tls")]
impl TlsConnection {
    /// Connects to `addr`, layers TLS on top of the TCP stream, and forwards every
    /// item of `receiver` over the resulting TLS stream.
    ///
    /// The TLS connector is built from `client_config` and the session is opened
    /// for `server_name`.
    ///
    /// # Errors
    ///
    /// Propagates any `std::io::Error` returned by `handle_tcp`, which includes
    /// errors from the TLS connect future.
    pub(super) async fn handle<S>(
        &self,
        addr: SocketAddr,
        receiver: &mut S,
    ) -> Result<(), std::io::Error>
    where
        S: Stream<Item = Bytes> + Unpin,
    {
        // Both clones are side-effect free, so they can be prepared up front and
        // moved into the wrapper that runs once the TCP stream is connected.
        let connector =
            tokio_rustls::TlsConnector::from(std::sync::Arc::clone(&self.client_config));
        let server_name = self.server_name.clone();

        handle_tcp(
            addr,
            move |tcp_stream| connector.connect(server_name, tcp_stream),
            receiver,
        )
        .await
    }
}

#[cfg(feature = "rustls-tls")]
impl std::fmt::Debug for TlsConnection {
    /// Formats as `TlsConnection { server_name: .., .. }`; `client_config` is not
    /// printed and the output is marked non-exhaustive.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("TlsConnection");
        builder.field("server_name", &self.server_name);
        builder.finish_non_exhaustive()
    }
}