flv_tls_proxy/
proxy.rs

1use std::{io::Error as IoError, sync::Arc};
2
3use anyhow::Result;
4use event_listener::Event;
5use futures_util::io::AsyncReadExt;
6use futures_util::stream::StreamExt;
7use tracing::{debug, error, info, instrument};
8
9use fluvio_future::net::TcpStream;
10use fluvio_future::openssl::{DefaultServerTlsStream, TlsAcceptor};
11
/// Shared handle to the event that, when notified, makes [`ProxyBuilder::start`] return.
type TerminateEvent = Arc<Event>;
13
14use crate::authenticator::{Authenticator, NullAuthenticator};
15
/// Authenticator shared (via `Arc`) by every per-connection task spawned from the accept loop.
type SharedAuthenticator = Arc<Box<dyn Authenticator>>;
17
18/// start TLS proxy at addr to target
19pub async fn start(addr: &str, acceptor: TlsAcceptor, target: String) -> Result<(), IoError> {
20    let builder = ProxyBuilder::new(addr.to_string(), acceptor, target);
21    builder.start().await
22}
23
24/// start TLS proxy with authenticator at addr to target
25pub async fn start_with_authenticator(
26    addr: &str,
27    acceptor: TlsAcceptor,
28    target: String,
29    authenticator: Box<dyn Authenticator>,
30) -> Result<(), IoError> {
31    let builder =
32        ProxyBuilder::new(addr.to_string(), acceptor, target).with_authenticator(authenticator);
33    builder.start().await
34}
35
36pub struct ProxyBuilder {
37    addr: String,
38    acceptor: TlsAcceptor,
39    target: String,
40    authenticator: Box<dyn Authenticator>,
41    terminate: TerminateEvent,
42}
43
44impl ProxyBuilder {
45    pub fn new(addr: String, acceptor: TlsAcceptor, target: String) -> Self {
46        Self {
47            addr,
48            acceptor,
49            target,
50            authenticator: Box::new(NullAuthenticator),
51            terminate: Arc::new(Event::new()),
52        }
53    }
54
55    pub fn with_authenticator(mut self, authenticator: Box<dyn Authenticator>) -> Self {
56        self.authenticator = authenticator;
57        self
58    }
59
60    pub fn with_terminate(mut self, terminate: TerminateEvent) -> Self {
61        self.terminate = terminate;
62        self
63    }
64
65    #[instrument(skip(self))]
66    pub async fn start(self) -> Result<(), IoError> {
67        use tokio::select;
68
69        use fluvio_future::net::TcpListener;
70        use fluvio_future::task::spawn;
71
72        let listener = TcpListener::bind(&self.addr).await?;
73        info!(self.addr, "proxy started at");
74        let mut incoming = listener.incoming();
75        let shared_authenticator = Arc::new(self.authenticator);
76
77        loop {
78            select! {
79                _ = self.terminate.listen() => {
80                    info!("terminate event received");
81                    return Ok(());
82                }
83                incoming_stream = incoming.next() => {
84                    if let Some(stream) = incoming_stream {
85                        debug!("server: got connection from client");
86                        if let Ok(tcp_stream) = stream {
87                            let acceptor = self.acceptor.clone();
88                            let target = self.target.clone();
89                            spawn(process_stream(
90                                acceptor,
91                                tcp_stream,
92                                target,
93                                shared_authenticator.clone()
94                            ));
95                        } else {
96                            error!("no stream detected");
97                            return Ok(());
98                        }
99
100                    } else {
101                        info!("no more incoming streaming");
102                        return Ok(());
103                    }
104                }
105
106            }
107        }
108    }
109}
110
111/// start TLS stream at addr to target
112#[instrument(skip(acceptor, raw_stream, authenticator))]
113async fn process_stream(
114    acceptor: TlsAcceptor,
115    raw_stream: TcpStream,
116    target: String,
117    authenticator: SharedAuthenticator,
118) {
119    let source = raw_stream
120        .peer_addr()
121        .map(|addr| addr.to_string())
122        .unwrap_or_else(|_| "".to_owned());
123
124    info!(source, "new connection from");
125
126    let handshake = acceptor.accept(raw_stream).await;
127
128    match handshake {
129        Ok(inner_stream) => {
130            info!(source, "handshake success");
131            if let Err(err) = proxy(inner_stream, target, source.clone(), authenticator).await {
132                error!("error processing tls: {} from source: {}", err, source);
133            }
134        }
135        Err(err) => error!("error handshaking: {} from source: {}", err, source),
136    }
137}
138
139#[instrument(skip(tls_stream, authenticator))]
140async fn proxy(
141    tls_stream: DefaultServerTlsStream,
142    target: String,
143    source: String,
144    authenticator: SharedAuthenticator,
145) -> Result<()> {
146    use crate::copy::copy;
147    use fluvio_future::task::spawn;
148
149    debug!("trying to connect to target");
150    let tcp_stream = TcpStream::connect(&target).await?;
151    info!("open tcp stream");
152
153    let auth_success = authenticator.authenticate(&tls_stream, &tcp_stream).await?;
154    if !auth_success {
155        info!("authentication failed, dropping connection");
156        return Ok(());
157    } else {
158        info!("authentication succeeded");
159    }
160
161    let (mut target_stream, mut target_sink) = tcp_stream.split();
162    let (mut from_tls_stream, mut from_tls_sink) = tls_stream.split();
163
164    let s_t = format!("{}->{}", source, target);
165    let t_s = format!("{}->{}", target, source);
166    let source_to_target_ft = async move {
167        match copy(&mut from_tls_stream, &mut target_sink, s_t.clone()).await {
168            Ok(len) => {
169                debug!(len, s_t, "total bytes copied from source to target");
170            }
171            Err(err) => {
172                error!("{} error copying: {}", s_t, err);
173            }
174        }
175    };
176
177    let target_to_source_ft = async move {
178        match copy(&mut target_stream, &mut from_tls_sink, t_s.clone()).await {
179            Ok(len) => {
180                debug!(len, t_s, "total bytes copied from target");
181            }
182            Err(err) => {
183                error!("{} error copying: {}", t_s, err);
184            }
185        }
186    };
187
188    spawn(source_to_target_ft);
189    spawn(target_to_source_ft);
190    Ok(())
191}