1use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio_rustls::TlsAcceptor;
15
16use super::registry::StreamRegistry;
17
18#[must_use]
27pub fn tls_acceptor_from_resolver(
28 resolver: Arc<dyn rustls::server::ResolvesServerCert>,
29) -> TlsAcceptor {
30 let config = rustls::ServerConfig::builder()
31 .with_no_client_auth()
32 .with_cert_resolver(resolver);
33 TlsAcceptor::from(Arc::new(config))
34}
35
36pub struct TcpStreamService {
41 registry: Arc<StreamRegistry>,
42 listen_port: u16,
43 tls_acceptor: Option<TlsAcceptor>,
47 proxy_protocol: bool,
50 local_addr: std::sync::OnceLock<SocketAddr>,
53}
54
55impl TcpStreamService {
56 #[must_use]
58 pub fn new(registry: Arc<StreamRegistry>, listen_port: u16) -> Self {
59 Self {
60 registry,
61 listen_port,
62 tls_acceptor: None,
63 proxy_protocol: false,
64 local_addr: std::sync::OnceLock::new(),
65 }
66 }
67
68 #[must_use]
70 pub fn with_tls_acceptor(mut self, acceptor: TlsAcceptor) -> Self {
71 self.tls_acceptor = Some(acceptor);
72 self
73 }
74
75 #[must_use]
77 pub fn with_proxy_protocol(mut self, enabled: bool) -> Self {
78 self.proxy_protocol = enabled;
79 self
80 }
81
82 #[must_use]
84 pub fn port(&self) -> u16 {
85 self.listen_port
86 }
87
88 #[must_use]
90 pub fn registry(&self) -> &Arc<StreamRegistry> {
91 &self.registry
92 }
93
94 pub async fn serve(self: Arc<Self>, listener: TcpListener) {
100 if let Ok(addr) = listener.local_addr() {
103 let _ = self.local_addr.set(addr);
104 }
105
106 tracing::info!(
107 port = self.listen_port,
108 tls = self.tls_acceptor.is_some(),
109 proxy_protocol = self.proxy_protocol,
110 "TCP stream proxy listening"
111 );
112
113 loop {
114 let (client_stream, client_addr) = match listener.accept().await {
115 Ok(conn) => conn,
116 Err(e) => {
117 tracing::warn!(
119 port = self.listen_port,
120 error = %e,
121 "TCP accept error, retrying"
122 );
123 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
124 continue;
125 }
126 };
127
128 let svc = Arc::clone(&self);
129 tokio::spawn(async move {
130 svc.handle_raw_connection(client_stream, client_addr).await;
131 });
132 }
133 }
134
135 async fn handle_raw_connection(
137 &self,
138 client_stream: tokio::net::TcpStream,
139 client_addr: SocketAddr,
140 ) {
141 let Some(service) = self.registry.resolve_tcp(self.listen_port) else {
143 tracing::warn!(
144 port = self.listen_port,
145 client = %client_addr,
146 "No service registered for TCP port"
147 );
148 return;
149 };
150
151 let Some(backend) = service.select_backend() else {
153 tracing::warn!(
154 port = self.listen_port,
155 service = %service.name,
156 client = %client_addr,
157 "No backends available for TCP service"
158 );
159 return;
160 };
161
162 tracing::debug!(
163 port = self.listen_port,
164 service = %service.name,
165 client = %client_addr,
166 backend = %backend,
167 "Proxying TCP connection"
168 );
169
170 let mut upstream = match tokio::net::TcpStream::connect(backend).await {
172 Ok(stream) => stream,
173 Err(e) => {
174 tracing::warn!(
175 error = %e,
176 backend = %backend,
177 service = %service.name,
178 client = %client_addr,
179 "Failed to connect to TCP backend"
180 );
181 return;
182 }
183 };
184
185 if self.proxy_protocol {
189 let dst = self
190 .local_addr
191 .get()
192 .copied()
193 .unwrap_or_else(|| SocketAddr::new(backend.ip(), self.listen_port));
194 let header = build_proxy_protocol_v2_header(client_addr, dst);
195 if let Err(e) = upstream.write_all(&header).await {
196 tracing::warn!(
197 error = %e,
198 backend = %backend,
199 service = %service.name,
200 client = %client_addr,
201 "Failed to write PROXY protocol header to backend"
202 );
203 return;
204 }
205 }
206
207 if let Some(acceptor) = &self.tls_acceptor {
210 match acceptor.accept(client_stream).await {
211 Ok(tls_stream) => {
212 Self::duplex(tls_stream, upstream).await;
213 }
214 Err(e) => {
215 tracing::warn!(
216 error = %e,
217 service = %service.name,
218 client = %client_addr,
219 "TLS handshake with client failed"
220 );
221 }
222 }
223 } else {
224 Self::duplex(client_stream, upstream).await;
225 }
226 }
227
228 async fn duplex<D, U>(mut downstream: D, mut upstream: U)
235 where
236 D: AsyncRead + AsyncWrite + Unpin,
237 U: AsyncRead + AsyncWrite + Unpin,
238 {
239 match tokio::io::copy_bidirectional(&mut downstream, &mut upstream).await {
240 Ok((down_to_up, up_to_down)) => {
241 tracing::debug!(
242 down_to_up = down_to_up,
243 up_to_down = up_to_down,
244 "TCP tunnel closed"
245 );
246 }
247 Err(e) => {
248 tracing::debug!(error = %e, "TCP tunnel error");
249 }
250 }
251 }
252
253 pub(crate) async fn splice<D, U>(downstream: D, upstream: U)
260 where
261 D: AsyncRead + AsyncWrite + Unpin,
262 U: AsyncRead + AsyncWrite + Unpin,
263 {
264 Self::duplex(downstream, upstream).await;
265 }
266}
267
/// Builds a binary PROXY protocol v2 header for a proxied TCP connection.
///
/// The header announces a connection from `src` (the client) to `dst` (the
/// address the client connected to).
///
/// Layout:
/// - the 12-byte v2 signature;
/// - a version/command byte (`0x21`: version 2, `PROXY`);
/// - an address-family/transport byte;
/// - a big-endian `u16` address-block length;
/// - the address block: source address, destination address, source port,
///   destination port (all big-endian).
///
/// Encodings:
/// - Both addresses IPv4 → `AF_INET`/`STREAM` (`0x11`), 12-byte block,
///   28 bytes in total.
/// - Otherwise → `AF_INET6`/`STREAM` (`0x21`), 36-byte block, 52 bytes in
///   total. An IPv4 endpoint is written as an IPv4-mapped IPv6 address
///   (`::ffff:a.b.c.d`), so mixed-family pairs keep both real addresses.
#[must_use]
pub fn build_proxy_protocol_v2_header(src: SocketAddr, dst: SocketAddr) -> Vec<u8> {
    // Fixed 12-byte PROXY protocol v2 signature.
    const SIG: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];
    // Version 2 (high nibble) with the PROXY command (low nibble).
    const VERSION_2_PROXY: u8 = 0x21;
    // Address family (high nibble) and STREAM transport (low nibble).
    const TCP_OVER_IPV4: u8 = 0x11;
    const TCP_OVER_IPV6: u8 = 0x21;
    // Signature + version/command + family + 2-byte length.
    const PREFIX_LEN: usize = 16;
    // Address block sizes: two addresses plus two 2-byte ports.
    const IPV4_BLOCK_LEN: u16 = 4 + 4 + 2 + 2;
    const IPV6_BLOCK_LEN: u16 = 16 + 16 + 2 + 2;

    // Widens an address to IPv6, using the IPv4-mapped form for IPv4.
    let as_ipv6 = |ip: std::net::IpAddr| match ip {
        std::net::IpAddr::V4(v4) => v4.to_ipv6_mapped(),
        std::net::IpAddr::V6(v6) => v6,
    };

    // Sized for the larger (IPv6) header so neither encoding reallocates.
    let mut out = Vec::with_capacity(PREFIX_LEN + usize::from(IPV6_BLOCK_LEN));
    out.extend_from_slice(&SIG);
    out.push(VERSION_2_PROXY);

    if let (SocketAddr::V4(src_v4), SocketAddr::V4(dst_v4)) = (src, dst) {
        out.push(TCP_OVER_IPV4);
        out.extend_from_slice(&IPV4_BLOCK_LEN.to_be_bytes());
        out.extend_from_slice(&src_v4.ip().octets());
        out.extend_from_slice(&dst_v4.ip().octets());
    } else {
        // At least one side is IPv6: encode both as IPv6 rather than replacing
        // a mismatched address with a placeholder.
        out.push(TCP_OVER_IPV6);
        out.extend_from_slice(&IPV6_BLOCK_LEN.to_be_bytes());
        out.extend_from_slice(&as_ipv6(src.ip()).octets());
        out.extend_from_slice(&as_ipv6(dst.ip()).octets());
    }
    out.extend_from_slice(&src.port().to_be_bytes());
    out.extend_from_slice(&dst.port().to_be_bytes());

    out
}
324
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    /// The fixed 12-byte PROXY protocol v2 signature every header starts with.
    const V2_SIGNATURE: [u8; 12] = [
        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
    ];

    #[test]
    fn proxy_protocol_v2_ipv4_exact_bytes() {
        let src = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 50), 0xABCD));
        let dst = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5432));
        let hdr = build_proxy_protocol_v2_header(src, dst);

        let mut expected = V2_SIGNATURE.to_vec();
        expected.extend_from_slice(&[
            0x21, // version 2, PROXY command
            0x11, // AF_INET, STREAM
            0x00, 0x0C, // address block length: 12
            192, 168, 1, 50, // source address
            10, 0, 0, 1, // destination address
            0xAB, 0xCD, // source port
            0x15, 0x38, // destination port (5432)
        ]);
        assert_eq!(hdr, expected);
        assert_eq!(hdr.len(), 16 + 12);
    }

    #[test]
    fn proxy_protocol_v2_ipv6_shape() {
        let src_ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let src = SocketAddr::V6(SocketAddrV6::new(src_ip, 7777, 0, 0));
        let dst = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 8888, 0, 0));
        let hdr = build_proxy_protocol_v2_header(src, dst);

        assert_eq!(hdr.len(), 16 + 36);
        assert_eq!(&hdr[..12], &V2_SIGNATURE);
        assert_eq!(hdr[12], 0x21); // version 2, PROXY command
        assert_eq!(hdr[13], 0x21); // AF_INET6, STREAM
        assert_eq!(&hdr[14..16], &36u16.to_be_bytes());
        assert_eq!(&hdr[16..32], &src_ip.octets());
        assert_eq!(&hdr[32..48], &Ipv6Addr::LOCALHOST.octets());
        assert_eq!(&hdr[48..50], &7777u16.to_be_bytes());
        assert_eq!(&hdr[50..52], &8888u16.to_be_bytes());
    }
}