1use std::sync::Arc;
4
5use anyhow::{bail, Context, Result};
6use tokio::{io::AsyncWriteExt, net::TcpStream, time::timeout};
7use tracing::{error, info, info_span, warn, Instrument};
8use uuid::Uuid;
9
10use crate::auth::Authenticator;
11use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT};
12
13pub struct Client {
15 conn: Option<Delimited<TcpStream>>,
17
18 to: String,
20
21 local_host: String,
23
24 local_port: u16,
26
27 remote_port: u16,
29
30 auth: Option<Authenticator>,
32}
33
34impl Client {
35 pub async fn new(
37 local_host: &str,
38 local_port: u16,
39 to: &str,
40 port: u16,
41 secret: Option<&str>,
42 ) -> Result<Self> {
43 let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
44 let auth = secret.map(Authenticator::new);
45 if let Some(auth) = &auth {
46 auth.client_handshake(&mut stream).await?;
47 }
48
49 stream.send(ClientMessage::Hello(port)).await?;
50 let remote_port = match stream.recv_timeout().await? {
51 Some(ServerMessage::Hello(remote_port)) => remote_port,
52 Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
53 Some(ServerMessage::Challenge(_)) => {
54 bail!("server requires authentication, but no client secret was provided");
55 }
56 Some(_) => bail!("unexpected initial non-hello message"),
57 None => bail!("unexpected EOF"),
58 };
59 info!(remote_port, "connected to server");
60 info!("listening at {to}:{remote_port}");
61
62 Ok(Client {
63 conn: Some(stream),
64 to: to.to_string(),
65 local_host: local_host.to_string(),
66 local_port,
67 remote_port,
68 auth,
69 })
70 }
71
72 pub fn remote_port(&self) -> u16 {
74 self.remote_port
75 }
76
77 pub async fn listen(mut self) -> Result<()> {
79 let mut conn = self.conn.take().unwrap();
80 let this = Arc::new(self);
81 loop {
82 match conn.recv().await? {
83 Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
84 Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
85 Some(ServerMessage::Heartbeat) => (),
86 Some(ServerMessage::Connection(id)) => {
87 let this = Arc::clone(&this);
88 tokio::spawn(
89 async move {
90 info!("new connection");
91 match this.handle_connection(id).await {
92 Ok(_) => info!("connection exited"),
93 Err(err) => warn!(%err, "connection exited with error"),
94 }
95 }
96 .instrument(info_span!("proxy", %id)),
97 );
98 }
99 Some(ServerMessage::Error(err)) => error!(%err, "server error"),
100 None => return Ok(()),
101 }
102 }
103 }
104
105 async fn handle_connection(&self, id: Uuid) -> Result<()> {
106 let mut remote_conn =
107 Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
108 if let Some(auth) = &self.auth {
109 auth.client_handshake(&mut remote_conn).await?;
110 }
111 remote_conn.send(ClientMessage::Accept(id)).await?;
112 let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
113 let mut parts = remote_conn.into_parts();
114 debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
115 local_conn.write_all(&parts.read_buf).await?; tokio::io::copy_bidirectional(&mut local_conn, &mut parts.io).await?;
117 Ok(())
118 }
119}
120
121async fn connect_with_timeout(to: &str, port: u16) -> Result<TcpStream> {
122 match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await {
123 Ok(res) => res,
124 Err(err) => Err(err.into()),
125 }
126 .with_context(|| format!("could not connect to {to}:{port}"))
127}