1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use std::sync::Arc;
use anyhow::{bail, Context, Result};
use tokio::io::AsyncWriteExt;
use tokio::{net::TcpStream, time::timeout};
use tracing::{error, info, info_span, warn, Instrument};
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::shared::{
proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT,
};
/// State held by the tunnelling client.
pub struct Client {
/// Control connection to the server. Set to `Some` by `Client::new` and
/// taken out (left as `None`) when `Client::listen` starts.
conn: Option<Delimited<TcpStream>>,
/// Host of the server; used both for the control connection and for every
/// per-connection stream opened in `handle_connection`.
to: String,
/// Local host that accepted connections are forwarded to.
local_host: String,
/// Local port that accepted connections are forwarded to.
local_port: u16,
/// Port reported back by the server in its `ServerMessage::Hello` reply.
remote_port: u16,
/// Authenticator built from the optional secret; when present, its client
/// handshake is run on each new connection to the server.
auth: Option<Authenticator>,
}
impl Client {
pub async fn new(
local_host: &str,
local_port: u16,
to: &str,
port: u16,
secret: Option<&str>,
) -> Result<Self> {
let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
let auth = secret.map(Authenticator::new);
if let Some(auth) = &auth {
auth.client_handshake(&mut stream).await?;
}
stream.send(ClientMessage::Hello(port)).await?;
let remote_port = match stream.recv_timeout().await? {
Some(ServerMessage::Hello(remote_port)) => remote_port,
Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
Some(ServerMessage::Challenge(_)) => {
bail!("server requires authentication, but no client secret was provided");
}
Some(_) => bail!("unexpected initial non-hello message"),
None => bail!("unexpected EOF"),
};
info!(remote_port, "connected to server");
info!("listening at {to}:{remote_port}");
Ok(Client {
conn: Some(stream),
to: to.to_string(),
local_host: local_host.to_string(),
local_port,
remote_port,
auth,
})
}
pub fn remote_port(&self) -> u16 {
self.remote_port
}
pub async fn listen(mut self) -> Result<()> {
let mut conn = self.conn.take().unwrap();
let this = Arc::new(self);
loop {
match conn.recv().await? {
Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
Some(ServerMessage::Heartbeat) => (),
Some(ServerMessage::Connection(id)) => {
let this = Arc::clone(&this);
tokio::spawn(
async move {
info!("new connection");
match this.handle_connection(id).await {
Ok(_) => info!("connection exited"),
Err(err) => warn!(%err, "connection exited with error"),
}
}
.instrument(info_span!("proxy", %id)),
);
}
Some(ServerMessage::Error(err)) => error!(%err, "server error"),
None => return Ok(()),
}
}
}
async fn handle_connection(&self, id: Uuid) -> Result<()> {
let mut remote_conn =
Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
if let Some(auth) = &self.auth {
auth.client_handshake(&mut remote_conn).await?;
}
remote_conn.send(ClientMessage::Accept(id)).await?;
let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
let parts = remote_conn.into_parts();
debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
local_conn.write_all(&parts.read_buf).await?; proxy(local_conn, parts.io).await?;
Ok(())
}
}
async fn connect_with_timeout(to: &str, port: u16) -> Result<TcpStream> {
match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await {
Ok(res) => res,
Err(err) => Err(err.into()),
}
.with_context(|| format!("could not connect to {to}:{port}"))
}