bore_cli/
client.rs

1//! Client implementation for the `bore` service.
2
3use std::sync::Arc;
4
5use anyhow::{bail, Context, Result};
6use tokio::{io::AsyncWriteExt, net::TcpStream, time::timeout};
7use tracing::{error, info, info_span, warn, Instrument};
8use uuid::Uuid;
9
10use crate::auth::Authenticator;
11use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT};
12
13/// State structure for the client.
14pub struct Client {
15    /// Control connection to the server.
16    conn: Option<Delimited<TcpStream>>,
17
18    /// Destination address of the server.
19    to: String,
20
21    // Local host that is forwarded.
22    local_host: String,
23
24    /// Local port that is forwarded.
25    local_port: u16,
26
27    /// Port that is publicly available on the remote.
28    remote_port: u16,
29
30    /// Optional secret used to authenticate clients.
31    auth: Option<Authenticator>,
32}
33
34impl Client {
35    /// Create a new client.
36    pub async fn new(
37        local_host: &str,
38        local_port: u16,
39        to: &str,
40        port: u16,
41        secret: Option<&str>,
42    ) -> Result<Self> {
43        let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?);
44        let auth = secret.map(Authenticator::new);
45        if let Some(auth) = &auth {
46            auth.client_handshake(&mut stream).await?;
47        }
48
49        stream.send(ClientMessage::Hello(port)).await?;
50        let remote_port = match stream.recv_timeout().await? {
51            Some(ServerMessage::Hello(remote_port)) => remote_port,
52            Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
53            Some(ServerMessage::Challenge(_)) => {
54                bail!("server requires authentication, but no client secret was provided");
55            }
56            Some(_) => bail!("unexpected initial non-hello message"),
57            None => bail!("unexpected EOF"),
58        };
59        info!(remote_port, "connected to server");
60        info!("listening at {to}:{remote_port}");
61
62        Ok(Client {
63            conn: Some(stream),
64            to: to.to_string(),
65            local_host: local_host.to_string(),
66            local_port,
67            remote_port,
68            auth,
69        })
70    }
71
72    /// Returns the port publicly available on the remote.
73    pub fn remote_port(&self) -> u16 {
74        self.remote_port
75    }
76
77    /// Start the client, listening for new connections.
78    pub async fn listen(mut self) -> Result<()> {
79        let mut conn = self.conn.take().unwrap();
80        let this = Arc::new(self);
81        loop {
82            match conn.recv().await? {
83                Some(ServerMessage::Hello(_)) => warn!("unexpected hello"),
84                Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"),
85                Some(ServerMessage::Heartbeat) => (),
86                Some(ServerMessage::Connection(id)) => {
87                    let this = Arc::clone(&this);
88                    tokio::spawn(
89                        async move {
90                            info!("new connection");
91                            match this.handle_connection(id).await {
92                                Ok(_) => info!("connection exited"),
93                                Err(err) => warn!(%err, "connection exited with error"),
94                            }
95                        }
96                        .instrument(info_span!("proxy", %id)),
97                    );
98                }
99                Some(ServerMessage::Error(err)) => error!(%err, "server error"),
100                None => return Ok(()),
101            }
102        }
103    }
104
105    async fn handle_connection(&self, id: Uuid) -> Result<()> {
106        let mut remote_conn =
107            Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
108        if let Some(auth) = &self.auth {
109            auth.client_handshake(&mut remote_conn).await?;
110        }
111        remote_conn.send(ClientMessage::Accept(id)).await?;
112        let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?;
113        let mut parts = remote_conn.into_parts();
114        debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty");
115        local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty
116        tokio::io::copy_bidirectional(&mut local_conn, &mut parts.io).await?;
117        Ok(())
118    }
119}
120
121async fn connect_with_timeout(to: &str, port: u16) -> Result<TcpStream> {
122    match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await {
123        Ok(res) => res,
124        Err(err) => Err(err.into()),
125    }
126    .with_context(|| format!("could not connect to {to}:{port}"))
127}