1use std::{io::Error as IoError, sync::Arc};
2
3use event_listener::Event;
4use futures_util::io::AsyncReadExt;
5use futures_util::stream::StreamExt;
6use log::debug;
7use log::error;
8use log::info;
9
10use fluvio_future::net::TcpStream;
11use fluvio_future::openssl::{DefaultServerTlsStream, TlsAcceptor};
12
13type TerminateEvent = Arc<Event>;
14
15use crate::authenticator::{Authenticator, NullAuthenticator};
16
17type SharedAuthenticator = Arc<Box<dyn Authenticator>>;
18
19pub async fn start(addr: &str, acceptor: TlsAcceptor, target: String) -> Result<(), IoError> {
21 let builder = ProxyBuilder::new(addr.to_string(), acceptor, target);
22 builder.start().await
23}
24
25pub async fn start_with_authenticator(
27 addr: &str,
28 acceptor: TlsAcceptor,
29 target: String,
30 authenticator: Box<dyn Authenticator>,
31) -> Result<(), IoError> {
32 let builder =
33 ProxyBuilder::new(addr.to_string(), acceptor, target).with_authenticator(authenticator);
34 builder.start().await
35}
36
37pub struct ProxyBuilder {
38 addr: String,
39 acceptor: TlsAcceptor,
40 target: String,
41 authenticator: Box<dyn Authenticator>,
42 terminate: TerminateEvent,
43}
44
45impl ProxyBuilder {
46 pub fn new(addr: String, acceptor: TlsAcceptor, target: String) -> Self {
47 Self {
48 addr,
49 acceptor,
50 target,
51 authenticator: Box::new(NullAuthenticator),
52 terminate: Arc::new(Event::new()),
53 }
54 }
55
56 pub fn with_authenticator(mut self, authenticator: Box<dyn Authenticator>) -> Self {
57 self.authenticator = authenticator;
58 self
59 }
60
61 pub fn with_terminate(mut self, terminate: TerminateEvent) -> Self {
62 self.terminate = terminate;
63 self
64 }
65
66 pub async fn start(self) -> Result<(), IoError> {
67 use tokio::select;
68
69 use fluvio_future::net::TcpListener;
70 use fluvio_future::task::spawn;
71
72 let listener = TcpListener::bind(&self.addr).await?;
73 info!("proxy started at: {}", self.addr);
74 let mut incoming = listener.incoming();
75 let shared_authenticator = Arc::new(self.authenticator);
76
77 loop {
78 select! {
79 _ = self.terminate.listen() => {
80 info!("terminate event received");
81 return Ok(());
82 }
83 incoming_stream = incoming.next() => {
84 if let Some(stream) = incoming_stream {
85 debug!("server: got connection from client");
86 if let Ok(tcp_stream) = stream {
87 let acceptor = self.acceptor.clone();
88 let target = self.target.clone();
89 spawn(process_stream(
90 acceptor,
91 tcp_stream,
92 target,
93 shared_authenticator.clone()
94 ));
95 } else {
96 error!("no stream detected");
97 return Ok(());
98 }
99
100 } else {
101 info!("no more incoming streaming");
102 return Ok(());
103 }
104 }
105
106 }
107 }
108 }
109}
110
111async fn process_stream(
113 acceptor: TlsAcceptor,
114 raw_stream: TcpStream,
115 target: String,
116 authenticator: SharedAuthenticator,
117) {
118 let source = raw_stream
119 .peer_addr()
120 .map(|addr| addr.to_string())
121 .unwrap_or_else(|_| "".to_owned());
122
123 debug!("new connection from {}", source);
124
125 let handshake = acceptor.accept(raw_stream).await;
126
127 match handshake {
128 Ok(inner_stream) => {
129 debug!("handshake success from: {}", source);
130 if let Err(err) = proxy(inner_stream, target, source.clone(), authenticator).await {
131 error!("error processing tls: {} from source: {}", err, source);
132 }
133 }
134 Err(err) => error!("error handshaking: {} from source: {}", err, source),
135 }
136}
137
138async fn proxy(
139 tls_stream: DefaultServerTlsStream,
140 target: String,
141 source: String,
142 authenticator: SharedAuthenticator,
143) -> Result<(), IoError> {
144 use crate::copy::copy;
145 use fluvio_future::task::spawn;
146
147 debug!(
148 "trying to connect to target at: {} from source: {}",
149 target, source
150 );
151 let tcp_stream = TcpStream::connect(&target).await?;
152
153 let auth_success = authenticator.authenticate(&tls_stream, &tcp_stream).await?;
154 if !auth_success {
155 debug!("authentication failed, dropping connection");
156 return Ok(());
157 } else {
158 debug!("authentication succeeded");
159 }
160
161 debug!("connect to target: {} from source: {}", target, source);
162
163 let (mut target_stream, mut target_sink) = tcp_stream.split();
164 let (mut from_tls_stream, mut from_tls_sink) = tls_stream.split();
165
166 let s_t = format!("{}->{}", source, target);
167 let t_s = format!("{}->{}", target, source);
168 let source_to_target_ft = async move {
169 match copy(&mut from_tls_stream, &mut target_sink, s_t.clone()).await {
170 Ok(len) => {
171 debug!("total {} bytes copied from source to target: {}", len, s_t);
172 }
173 Err(err) => {
174 error!("{} error copying: {}", s_t, err);
175 }
176 }
177 };
178
179 let target_to_source_ft = async move {
180 match copy(&mut target_stream, &mut from_tls_sink, t_s.clone()).await {
181 Ok(len) => {
182 debug!("total {} bytes copied from target: {}", len, t_s);
183 }
184 Err(err) => {
185 error!("{} error copying: {}", t_s, err);
186 }
187 }
188 };
189
190 spawn(source_to_target_ft);
191 spawn(target_to_source_ft);
192 Ok(())
193}