1use std::sync::Arc;
4
5use anyhow::{bail, Context, Result};
6use tokio::{io::AsyncWriteExt, net::TcpStream, time::timeout};
7use tracing::{error, info, info_span, warn, Instrument};
8use uuid::Uuid;
9
10use crate::auth::Authenticator;
11use crate::shared::{
12 proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT,
13};
14
15pub struct Client {
17 conn: Option<Delimited<TcpStream>>,
19
20 to: String,
22
23 local_host: String,
25
26 local_port: u16,
28
29 remote_port: u16,
31
32 auth: Option<Authenticator>,
34}
35
36impl Client {
37 pub async fn new(
39 local_host: &str,
40 local_port: u16,
41 to: &str,
42 port: u16,
43 secret: Option<&str>,
44 ) -> Result<Self> {
45 let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
46 let auth = secret.map(Authenticator::new);
47 if let Some(auth) = &auth {
48 auth.client_handshake(&mut stream).await?;
49 }
50
51 stream.send(ClientMessage::Hello(port)).await?;
52 let remote_port = match stream.recv_timeout().await? {
53 Some(ServerMessage::Hello(remote_port)) => remote_port,
54 Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
55 Some(ServerMessage::Challenge(_)) => {
56 bail!("server requires authentication, but no client secret was provided");
57 }
58 Some(_) => bail!("unexpected initial non-hello message"),
59 None => bail!("unexpected EOF"),
60 };
61 info!(remote_port, "connected to server");
62 info!("listening at {to}:{remote_port}");
63
64 Ok(Client {
65 conn: Some(stream),
66 to: to.to_string(),
67 local_host: local_host.to_string(),
68 local_port,
69 remote_port,
70 auth,
71 })
72 }
73
74 pub fn remote_port(&self) -> u16 {
76 self.remote_port
77 }
78
79 pub async fn listen(mut self) -> Result<()> {
81 let mut conn = self.conn.take().unwrap();
82 let this = Arc::new(self);
83 loop {
84 match conn.recv().await? {
85 Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
86 Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
87 Some(ServerMessage::Heartbeat) => (),
88 Some(ServerMessage::Connection(id)) => {
89 let this = Arc::clone(&this);
90 tokio::spawn(
91 async move {
92 info!("new connection");
93 match this.handle_connection(id).await {
94 Ok(_) => info!("connection exited"),
95 Err(err) => warn!(%err, "connection exited with error"),
96 }
97 }
98 .instrument(info_span!("proxy", %id)),
99 );
100 }
101 Some(ServerMessage::Error(err)) => error!(%err, "server error"),
102 None => return Ok(()),
103 }
104 }
105 }
106
107 async fn handle_connection(&self, id: Uuid) -> Result<()> {
108 let mut remote_conn =
109 Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
110 if let Some(auth) = &self.auth {
111 auth.client_handshake(&mut remote_conn).await?;
112 }
113 remote_conn.send(ClientMessage::Accept(id)).await?;
114 let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
115 let parts = remote_conn.into_parts();
116 debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
117 local_conn.write_all(&parts.read_buf).await?; proxy(local_conn, parts.io).await?;
119 Ok(())
120 }
121}
122
123async fn connect_with_timeout(to: &str, port: u16) -> Result<TcpStream> {
124 match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await {
125 Ok(res) => res,
126 Err(err) => Err(err.into()),
127 }
128 .with_context(|| format!("could not connect to {to}:{port}"))
129}