tracing_gelf/connection/tcp.rs

1use std::{future::Future, net::SocketAddr};
2
3use bytes::Bytes;
4use futures_util::{Stream, StreamExt};
5use tokio::{io, net::TcpStream};
6use tokio_util::codec::{BytesCodec, FramedWrite};
7
8/// Handle TCP connection, generic over TCP/TLS via `F`.
9async fn handle_tcp<F, R, S, I>(
10    addr: SocketAddr,
11    f: F,
12    receiver: &mut S,
13) -> Result<(), std::io::Error>
14where
15    S: Stream<Item = Bytes>,
16    S: Unpin,
17    I: io::AsyncRead + io::AsyncWrite + Send + Unpin,
18    F: FnOnce(TcpStream) -> R,
19    R: Future<Output = Result<I, std::io::Error>> + Send,
20{
21    let tcp = TcpStream::connect(addr).await?;
22    let wrapped = (f)(tcp).await?;
23    let (_, writer) = io::split(wrapped);
24
25    // Writer
26    let sink = FramedWrite::new(writer, BytesCodec::new());
27    receiver.map(Ok).forward(sink).await?;
28
29    Ok(())
30}
31
/// A plain, unencrypted TCP connection to Graylog.
///
/// This is a unit struct that carries no configuration. The destination
/// address is supplied on each call that opens a connection.
#[derive(Debug)]
pub struct TcpConnection;
35
36impl TcpConnection {
37    pub(super) async fn handle<S>(
38        &self,
39        addr: SocketAddr,
40        receiver: &mut S,
41    ) -> Result<(), std::io::Error>
42    where
43        S: Stream<Item = Bytes> + Unpin,
44    {
45        let wrapper = |tcp_stream| async { Ok(tcp_stream) };
46        handle_tcp(addr, wrapper, receiver).await
47    }
48}
49
/// A TLS-encrypted connection to Graylog (requires the `rustls-tls` feature).
#[cfg(feature = "rustls-tls")]
pub struct TlsConnection {
    /// Server name handed to the TLS connector when the handshake starts.
    pub(crate) server_name: rustls_pki_types::ServerName<'static>,
    /// Shared rustls client configuration from which the TLS connector is built.
    pub(crate) client_config: std::sync::Arc<tokio_rustls::rustls::ClientConfig>,
}
56
#[cfg(feature = "rustls-tls")]
impl TlsConnection {
    /// Connects to `addr` and upgrades the socket to TLS. It then writes every
    /// item yielded by `receiver` over the encrypted stream until the stream ends.
    ///
    /// The handshake uses `self.client_config` and `self.server_name`.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the TCP connection, the TLS handshake, or a
    /// write fails.
    pub(super) async fn handle<S>(
        &self,
        addr: SocketAddr,
        receiver: &mut S,
    ) -> Result<(), std::io::Error>
    where
        S: Stream<Item = Bytes> + Unpin,
    {
        // Cloning the `Arc` only bumps a reference count.
        let config = std::sync::Arc::clone(&self.client_config);
        let connector = tokio_rustls::TlsConnector::from(config);
        let server_name = self.server_name.clone();

        let upgrade = move |tcp_stream: TcpStream| connector.connect(server_name, tcp_stream);
        handle_tcp(addr, upgrade, receiver).await
    }
}
75
#[cfg(feature = "rustls-tls")]
impl std::fmt::Debug for TlsConnection {
    /// Formats the connection showing only `server_name`.
    ///
    /// `client_config` is not printed. `finish_non_exhaustive` renders a
    /// trailing `..` so readers can tell that a field was left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TlsConnection");
        out.field("server_name", &self.server_name);
        out.finish_non_exhaustive()
    }
}