flv_tls_proxy/
proxy.rs

1use std::{io::Error as IoError, sync::Arc};
2
3use event_listener::Event;
4use futures_util::io::AsyncReadExt;
5use futures_util::stream::StreamExt;
6use log::debug;
7use log::error;
8use log::info;
9
10use fluvio_future::net::TcpStream;
11use fluvio_future::openssl::{DefaultServerTlsStream, TlsAcceptor};
12
13type TerminateEvent = Arc<Event>;
14
15use crate::authenticator::{Authenticator, NullAuthenticator};
16
17type SharedAuthenticator = Arc<Box<dyn Authenticator>>;
18
19/// start TLS proxy at addr to target
20pub async fn start(addr: &str, acceptor: TlsAcceptor, target: String) -> Result<(), IoError> {
21    let builder = ProxyBuilder::new(addr.to_string(), acceptor, target);
22    builder.start().await
23}
24
25/// start TLS proxy with authenticator at addr to target
26pub async fn start_with_authenticator(
27    addr: &str,
28    acceptor: TlsAcceptor,
29    target: String,
30    authenticator: Box<dyn Authenticator>,
31) -> Result<(), IoError> {
32    let builder =
33        ProxyBuilder::new(addr.to_string(), acceptor, target).with_authenticator(authenticator);
34    builder.start().await
35}
36
37pub struct ProxyBuilder {
38    addr: String,
39    acceptor: TlsAcceptor,
40    target: String,
41    authenticator: Box<dyn Authenticator>,
42    terminate: TerminateEvent,
43}
44
45impl ProxyBuilder {
46    pub fn new(addr: String, acceptor: TlsAcceptor, target: String) -> Self {
47        Self {
48            addr,
49            acceptor,
50            target,
51            authenticator: Box::new(NullAuthenticator),
52            terminate: Arc::new(Event::new()),
53        }
54    }
55
56    pub fn with_authenticator(mut self, authenticator: Box<dyn Authenticator>) -> Self {
57        self.authenticator = authenticator;
58        self
59    }
60
61    pub fn with_terminate(mut self, terminate: TerminateEvent) -> Self {
62        self.terminate = terminate;
63        self
64    }
65
66    pub async fn start(self) -> Result<(), IoError> {
67        use tokio::select;
68
69        use fluvio_future::net::TcpListener;
70        use fluvio_future::task::spawn;
71
72        let listener = TcpListener::bind(&self.addr).await?;
73        info!("proxy started at: {}", self.addr);
74        let mut incoming = listener.incoming();
75        let shared_authenticator = Arc::new(self.authenticator);
76
77        loop {
78            select! {
79                _ = self.terminate.listen() => {
80                    info!("terminate event received");
81                    return Ok(());
82                }
83                incoming_stream = incoming.next() => {
84                    if let Some(stream) = incoming_stream {
85                        debug!("server: got connection from client");
86                        if let Ok(tcp_stream) = stream {
87                            let acceptor = self.acceptor.clone();
88                            let target = self.target.clone();
89                            spawn(process_stream(
90                                acceptor,
91                                tcp_stream,
92                                target,
93                                shared_authenticator.clone()
94                            ));
95                        } else {
96                            error!("no stream detected");
97                            return Ok(());
98                        }
99
100                    } else {
101                        info!("no more incoming streaming");
102                        return Ok(());
103                    }
104                }
105
106            }
107        }
108    }
109}
110
111/// start TLS stream at addr to target
112async fn process_stream(
113    acceptor: TlsAcceptor,
114    raw_stream: TcpStream,
115    target: String,
116    authenticator: SharedAuthenticator,
117) {
118    let source = raw_stream
119        .peer_addr()
120        .map(|addr| addr.to_string())
121        .unwrap_or_else(|_| "".to_owned());
122
123    debug!("new connection from {}", source);
124
125    let handshake = acceptor.accept(raw_stream).await;
126
127    match handshake {
128        Ok(inner_stream) => {
129            debug!("handshake success from: {}", source);
130            if let Err(err) = proxy(inner_stream, target, source.clone(), authenticator).await {
131                error!("error processing tls: {} from source: {}", err, source);
132            }
133        }
134        Err(err) => error!("error handshaking: {} from source: {}", err, source),
135    }
136}
137
138async fn proxy(
139    tls_stream: DefaultServerTlsStream,
140    target: String,
141    source: String,
142    authenticator: SharedAuthenticator,
143) -> Result<(), IoError> {
144    use crate::copy::copy;
145    use fluvio_future::task::spawn;
146
147    debug!(
148        "trying to connect to target at: {} from source: {}",
149        target, source
150    );
151    let tcp_stream = TcpStream::connect(&target).await?;
152
153    let auth_success = authenticator.authenticate(&tls_stream, &tcp_stream).await?;
154    if !auth_success {
155        debug!("authentication failed, dropping connection");
156        return Ok(());
157    } else {
158        debug!("authentication succeeded");
159    }
160
161    debug!("connect to target: {} from source: {}", target, source);
162
163    let (mut target_stream, mut target_sink) = tcp_stream.split();
164    let (mut from_tls_stream, mut from_tls_sink) = tls_stream.split();
165
166    let s_t = format!("{}->{}", source, target);
167    let t_s = format!("{}->{}", target, source);
168    let source_to_target_ft = async move {
169        match copy(&mut from_tls_stream, &mut target_sink, s_t.clone()).await {
170            Ok(len) => {
171                debug!("total {} bytes copied from source to target: {}", len, s_t);
172            }
173            Err(err) => {
174                error!("{} error copying: {}", s_t, err);
175            }
176        }
177    };
178
179    let target_to_source_ft = async move {
180        match copy(&mut target_stream, &mut from_tls_sink, t_s.clone()).await {
181            Ok(len) => {
182                debug!("total {} bytes copied from target: {}", len, t_s);
183            }
184            Err(err) => {
185                error!("{} error copying: {}", t_s, err);
186            }
187        }
188    };
189
190    spawn(source_to_target_ft);
191    spawn(target_to_source_ft);
192    Ok(())
193}