use std::{
collections::HashMap,
env, io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin,
sync::Arc,
task::Poll,
time::Duration,
};
use crate::{
contracts::{TunnelConnectionMode, TunnelEndpoint, TunnelPort, TunnelRelayTunnelEndpoint},
management::{
Authorization, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions,
NO_REQUEST_OPTIONS,
},
};
use async_trait::async_trait;
use futures::{stream::FuturesUnordered, StreamExt, TryFutureExt};
use russh::{server::Server as ServerTrait, CryptoVec};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
sync::{mpsc, oneshot, watch},
task::JoinHandle,
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use uuid::Uuid;
use super::{
errors::TunnelError,
ws::{build_websocket_request, connect_directly, connect_via_proxy, AsyncRWWebSocket},
};
/// Map from a forwarded port number to the sender that receives incoming
/// connections for that port. It is published through a `watch` channel so
/// every client session can observe ports being added and removed.
type PortMap = HashMap<u32, mpsc::UnboundedSender<ForwardedPortConnection>>;
/// Hosts a tunnel through the tunnel relay service.
///
/// The host registers a relay endpoint with the management service, connects
/// to the relay over a websocket, and serves each client stream the relay opens
/// with an in-process SSH server that forwards the registered ports.
pub struct RelayTunnelHost {
    /// HTTPS proxy used to reach the relay, read from the environment when the
    /// host is created.
    pub proxy: Option<String>,
    locator: TunnelLocator,
    /// Randomly generated identifier for this host; the relay endpoint ID is
    /// derived from it.
    host_id: Uuid,
    /// Publishes the current port map to all client sessions.
    ports_tx: watch::Sender<PortMap>,
    /// Receiver that is cloned for each client session.
    ports_rx: watch::Receiver<PortMap>,
    mgmt: TunnelManagementClient,
    /// Host key presented by the per-client SSH servers.
    host_keypair: russh_keys::key::KeyPair,
}
// This is public API; not every method is necessarily used inside this crate,
// so unused-method warnings are silenced for the whole impl.
#[allow(dead_code)]
impl RelayTunnelHost {
    /// Creates a host for the tunnel identified by `locator`.
    ///
    /// Generates a random host ID and a 2048-bit RSA host key pair, and reads
    /// the HTTPS proxy from `HTTPS_PROXY` / `https_proxy`.
    ///
    /// # Panics
    ///
    /// Panics if the RSA key pair cannot be generated.
    pub fn new(locator: TunnelLocator, mgmt: TunnelManagementClient) -> Self {
        let (ports_tx, ports_rx) = watch::channel(HashMap::new());
        RelayTunnelHost {
            proxy: Self::proxy_from_env(),
            host_id: Uuid::new_v4(),
            locator,
            ports_tx,
            ports_rx,
            mgmt,
            host_keypair: russh_keys::key::KeyPair::generate_rsa(
                2048,
                russh_keys::key::SignatureHash::SHA2_512,
            )
            .expect("expected to generate rsa keypair"),
        }
    }

    /// Reads the HTTPS proxy from the environment, preferring `HTTPS_PROXY`
    /// over `https_proxy`.
    ///
    /// Variables that are unset, not valid Unicode, or empty are skipped. An
    /// empty value conventionally means "no proxy" and is not a usable address.
    fn proxy_from_env() -> Option<String> {
        ["HTTPS_PROXY", "https_proxy"]
            .iter()
            .find_map(|name| env::var(name).ok().filter(|value| !value.is_empty()))
    }

    /// ID of the relay endpoint this host registers (`{host_id}-relay`).
    fn relay_endpoint_id(&self) -> String {
        format!("{}-relay", self.host_id)
    }

    /// Connects to the relay as a host and starts serving client sessions.
    ///
    /// Registers (or updates) the relay endpoint using `host_token`, opens a
    /// websocket to the relay, and establishes the primary SSH client session
    /// over it. A background task then starts an SSH server for every
    /// `client-ssh-session-stream` channel the relay opens, and routes that
    /// channel's data to it. The returned [`RelayHandle`] resolves when the
    /// background task ends.
    ///
    /// # Errors
    ///
    /// Fails if endpoint registration, the websocket connection, or the SSH
    /// handshake with the relay fails.
    pub async fn connect(&mut self, host_token: &str) -> Result<RelayHandle, TunnelError> {
        let (websocket, endpoint) = self.create_websocket(host_token).await?;
        let cnx = AsyncRWWebSocket::new(super::ws::AsyncRWWebSocketOptions {
            websocket,
            ping_interval: Duration::from_secs(60),
            ping_timeout: Duration::from_secs(10),
        });
        let (client_session, mut rx) = RelayTunnelHost::make_ssh_client(cnx)
            .await
            .map_err(TunnelError::TunnelRelayDisconnected)?;
        let client_session = Arc::new(client_session);
        let client_session_ret = client_session.clone();
        log::debug!("established host relay primary session");
        let ports_rx = self.ports_rx.clone();
        let host_keypair = self.host_keypair.clone();
        let join = tokio::spawn(async move {
            let mut server = RelayTunnelHost::make_ssh_server(host_keypair);
            // Senders feeding the byte stream of each open client channel.
            let mut channels = HashMap::new();
            while let Some(op) = rx.recv().await {
                match op {
                    ChannelOp::Open(id) => {
                        let (rw, sender) = AsyncRWChannel::new(id, client_session.clone());
                        server.run_stream(rw, ports_rx.clone());
                        channels.insert(id, sender);
                        log::info!("Opened new client on channel {}", id);
                    }
                    ChannelOp::Close(id) => {
                        channels.remove(&id);
                    }
                    ChannelOp::Data(id, data) => {
                        if let Some(ch) = channels.get(&id) {
                            // The stream's reader is gone; stop routing to it.
                            if ch.send(data).is_err() {
                                channels.remove(&id);
                            }
                        }
                    }
                }
            }
            client_session
                .disconnect(russh::Disconnect::ByApplication, "going away", "en")
                .await
                .ok();
            log::debug!("disconnected primary session after EOF");
            Ok(())
        });
        Ok(RelayHandle {
            endpoint,
            join,
            session: client_session_ret,
        })
    }

    /// Deletes this host's relay endpoint from the tunnel.
    ///
    /// Returns the boolean reported by the management client. Presumably this
    /// says whether anything was deleted; confirm against
    /// `delete_tunnel_endpoints`.
    pub async fn unregister(&self) -> Result<bool, TunnelError> {
        self.mgmt
            .delete_tunnel_endpoints(&self.locator, &self.relay_endpoint_id(), NO_REQUEST_OPTIONS)
            .await
            .map_err(|e| TunnelError::HttpError {
                error: e,
                reason: "could not unregister relay",
            })
    }

    /// Registers `port_to_add` with the tunnel service and starts accepting
    /// connections for it on every client session.
    ///
    /// Returns a receiver that yields each incoming [`ForwardedPortConnection`]
    /// for the port; the caller must service them. If the service already knows
    /// the port (HTTP 409), that is treated as success.
    ///
    /// # Errors
    ///
    /// - [`TunnelError::PortAlreadyExists`] if this host already forwards the
    ///   port. This includes a concurrent call that registered it first.
    /// - [`TunnelError::HttpError`] if the service rejects the port.
    pub async fn add_port_raw(
        &self,
        port_to_add: &TunnelPort,
    ) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, TunnelError> {
        let n = u32::from(port_to_add.port_number);
        if self.ports_tx.borrow().contains_key(&n) {
            return Err(TunnelError::PortAlreadyExists(n));
        }
        let tunnel_port = self
            .mgmt
            .create_tunnel_port(&self.locator, port_to_add, NO_REQUEST_OPTIONS)
            .await;
        match tunnel_port {
            Ok(_) => {}
            // Already registered on the service side; that's fine.
            Err(HttpError::ResponseError(e)) if e.status_code == 409 => {}
            Err(e) => {
                return Err(TunnelError::HttpError {
                    error: e,
                    reason: "failed to add port to tunnel",
                })
            }
        }
        let (tx, rx) = mpsc::unbounded_channel();
        // Re-check under the update: another call may have added the same port
        // while the request above was in flight. Blindly inserting would
        // replace that caller's sender, and its receiver would never get any
        // connections.
        let mut inserted = false;
        self.ports_tx.send_modify(|ports| {
            if !ports.contains_key(&n) {
                ports.insert(n, tx);
                inserted = true;
            }
        });
        if inserted {
            Ok(rx)
        } else {
            Err(TunnelError::PortAlreadyExists(n))
        }
    }

    /// Like [`Self::add_port_raw`], but forwards every incoming connection to
    /// the same port on the local loopback interface, on a background task.
    pub async fn add_port(&self, port_to_add: &TunnelPort) -> Result<(), TunnelError> {
        let rx = self.add_port_raw(port_to_add).await?;
        tokio::spawn(forward_port_to_tcp(port_to_add.port_number, rx));
        Ok(())
    }

    /// Removes `port_number` from the tunnel service, then removes it from the
    /// port map shared with client sessions so they stop forwarding it.
    pub async fn remove_port(&self, port_number: u16) -> Result<(), TunnelError> {
        self.mgmt
            .delete_tunnel_port(&self.locator, port_number, NO_REQUEST_OPTIONS)
            .await
            .map_err(|e| TunnelError::HttpError {
                error: e,
                reason: "failed to remove port from tunnel",
            })?;
        self.ports_tx.send_modify(|v| {
            v.remove(&u32::from(port_number));
        });
        Ok(())
    }

    /// Builds the SSH server used for each client stream, with `keypair` as its
    /// host key.
    ///
    /// There is no connection timeout, rekey limits are set to their maximums,
    /// and russh's `Preferred::COMPRESSED` algorithm set is used.
    fn make_ssh_server(keypair: russh_keys::key::KeyPair) -> Server {
        let c = russh::server::Config {
            connection_timeout: None,
            auth_rejection_time: std::time::Duration::from_secs(5),
            keys: vec![keypair],
            window_size: 1024 * 1024,
            preferred: russh::Preferred::COMPRESSED,
            limits: russh::Limits {
                rekey_read_limit: usize::MAX,
                rekey_time_limit: Duration::MAX,
                rekey_write_limit: usize::MAX,
            },
            ..Default::default()
        };
        let config = Arc::new(c);
        Server { config }
    }

    /// Establishes the primary SSH client session with the relay over `rw`.
    ///
    /// Key exchange, host key, and cipher are all negotiated as `NONE`, and
    /// compression is off. Returns the session handle plus the receiver for the
    /// channel events reported by the [`Client`] handler.
    async fn make_ssh_client(
        rw: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
    ) -> Result<
        (
            russh::client::Handle<Client>,
            mpsc::UnboundedReceiver<ChannelOp>,
        ),
        russh::Error,
    > {
        let config = russh::client::Config {
            anonymous: true,
            window_size: 1024 * 1024 * 5,
            preferred: russh::Preferred {
                kex: &[russh::kex::NONE],
                key: &[russh_keys::key::NONE],
                cipher: &[russh::cipher::NONE],
                mac: russh::Preferred::DEFAULT.mac,
                compression: &["none"],
            },
            limits: russh::Limits {
                rekey_read_limit: 1024 * 1024 * 8,
                rekey_time_limit: std::time::Duration::from_secs(60),
                rekey_write_limit: 1024 * 1024 * 8,
            },
            ..Default::default()
        };
        let config = Arc::new(config);
        let (client, rx) = Client::new();
        let session = russh::client::connect_stream(config, rw, client).await?;
        Ok((session, rx))
    }

    /// Registers (or updates) this host's relay endpoint, then opens a
    /// websocket to the host relay URI the service returns. The connection goes
    /// through `self.proxy` when one is set.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the endpoint update fails;
    /// - the response carries no host relay URI
    ///   ([`TunnelError::MissingHostEndpoint`]);
    /// - the websocket cannot be opened.
    async fn create_websocket(
        &self,
        host_token: &str,
    ) -> Result<
        (
            WebSocketStream<MaybeTlsStream<TcpStream>>,
            TunnelRelayTunnelEndpoint,
        ),
        TunnelError,
    > {
        let endpoint = self
            .mgmt
            .update_tunnel_relay_endpoints(
                &self.locator,
                &TunnelRelayTunnelEndpoint {
                    base: TunnelEndpoint {
                        id: self.relay_endpoint_id(),
                        connection_mode: TunnelConnectionMode::TunnelRelay,
                        host_id: self.host_id.to_string(),
                        host_public_keys: vec![],
                        port_uri_format: None,
                        port_ssh_command_format: None,
                        ssh_gateway_public_key: None,
                        tunnel_ssh_command: None,
                        tunnel_uri: None,
                    },
                    client_relay_uri: None,
                    host_relay_uri: None,
                },
                &TunnelRequestOptions {
                    authorization: Some(Authorization::Tunnel(host_token.to_string())),
                    ..TunnelRequestOptions::default()
                },
            )
            .await
            .map_err(|e| TunnelError::HttpError {
                error: e,
                reason: "failed to update tunnel endpoint for hosting",
            })?;
        let url = endpoint
            .host_relay_uri
            .as_deref()
            .ok_or(TunnelError::MissingHostEndpoint)?;
        let req = build_websocket_request(
            url,
            &[
                ("Sec-WebSocket-Protocol", "tunnel-relay-host"),
                ("Authorization", &format!("tunnel {}", host_token)),
                ("User-Agent", self.mgmt.user_agent.to_str().unwrap()),
            ],
        )?;
        let cnx = if let Some(proxy) = &self.proxy {
            log::debug!("connecting via http_proxy on {}", proxy);
            connect_via_proxy(req, proxy).await?
        } else {
            connect_directly(req).await?
        };
        Ok((cnx, endpoint))
    }
}
/// An incoming connection on a forwarded port, carried over an SSH channel of
/// a client session.
pub struct ForwardedPortConnection {
    /// Port number the client asked to connect to.
    port: u32,
    /// SSH channel carrying this connection.
    channel: russh::ChannelId,
    /// Handle to the SSH server session that owns `channel`.
    handle: russh::server::Handle,
    /// Data the client sends on `channel`.
    receiver: mpsc::Receiver<Vec<u8>>,
}
impl ForwardedPortConnection {
pub async fn send(&mut self, d: &[u8]) -> Result<(), ()> {
self.handle
.data(self.channel, CryptoVec::from_slice(d))
.map_err(|_| ())
.await
}
pub async fn recv(&mut self) -> Option<Vec<u8>> {
self.receiver.recv().await
}
pub async fn close(self) {
self.handle.close(self.channel).await.ok();
}
pub fn into_rw(self) -> ForwardedPortRW {
let (w, r) = self.into_split();
ForwardedPortRW(r, w)
}
pub fn into_split(self) -> (ForwardedPortWriter, ForwardedPortReader) {
(
ForwardedPortWriter {
channel: self.channel,
handle: self.handle,
is_write_fut_valid: false,
write_fut: tokio_util::sync::ReusableBoxFuture::new(make_server_write_fut(None)),
},
ForwardedPortReader {
receiver: self.receiver,
readbuf: super::io::ReadBuffer::default(),
},
)
}
}
/// Write half of a [`ForwardedPortConnection`]. Implements `AsyncWrite` by
/// sending each buffer as SSH channel data.
pub struct ForwardedPortWriter {
    channel: russh::ChannelId,
    handle: russh::server::Handle,
    /// Whether `write_fut` holds a send that has not completed yet.
    is_write_fut_valid: bool,
    /// Reusable allocation for the in-flight send future.
    write_fut: tokio_util::sync::ReusableBoxFuture<'static, Result<(), russh::CryptoVec>>,
}
/// Builds the future that writes one buffer to a server-side SSH channel.
///
/// `None` only seeds a `ReusableBoxFuture` with a placeholder that is never
/// polled; polling that placeholder panics.
async fn make_server_write_fut(
    data: Option<(russh::server::Handle, russh::ChannelId, Vec<u8>)>,
) -> Result<(), russh::CryptoVec> {
    if let Some((handle, channel, bytes)) = data {
        handle.data(channel, CryptoVec::from(bytes)).await
    } else {
        unreachable!("this future should not be pollable in this state")
    }
}
impl AsyncWrite for ForwardedPortWriter {
    /// Starts sending `buf` as channel data (unless a send is already in
    /// flight) and drives the send. The whole buffer is reported as written
    /// once the send completes.
    ///
    /// While a send is in flight, later calls keep driving that send and ignore
    /// their own `buf`. A caller that retries after `Pending` with different
    /// data will therefore have the original data sent.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        if !self.is_write_fut_valid {
            let handle = self.handle.clone();
            let id = self.channel;
            self.write_fut
                .set(make_server_write_fut(Some((handle, id, buf.to_vec()))));
            self.is_write_fut_valid = true;
        }
        self.poll_flush(cx).map(|r| r.map(|_| buf.len()))
    }

    /// Drives any in-flight send to completion. A failed send is reported as
    /// an `io::ErrorKind::Other` "EOF" error.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        if !self.is_write_fut_valid {
            return Poll::Ready(Ok(()));
        }
        match self.write_fut.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(_)) => {
                self.is_write_fut_valid = false;
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(_)) => {
                self.is_write_fut_valid = false;
                Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "EOF")))
            }
        }
    }

    /// Completes any in-flight send. Nothing further (such as a channel EOF)
    /// is sent.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        // `AsyncWrite` requires shutdown to imply a flush. Reporting Ready
        // while a send is still pending would break that guarantee.
        self.poll_flush(cx)
    }
}
/// Read half of a [`ForwardedPortConnection`]. Implements `AsyncRead` over the
/// data received on the SSH channel.
pub struct ForwardedPortReader {
    receiver: mpsc::Receiver<Vec<u8>>,
    /// Presumably holds received data not yet handed to the caller (e.g. a
    /// message larger than the caller's buffer). `ReadBuffer` lives in
    /// `super::io`; confirm.
    readbuf: super::io::ReadBuffer,
}
impl AsyncRead for ForwardedPortReader {
    /// Reads data received on the channel. Data buffered in `readbuf` from an
    /// earlier message is returned first.
    ///
    /// When the channel's data stream ends, EOF is reported the way `AsyncRead`
    /// defines it: `Ok(())` with nothing written to `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if let Some((v, s)) = self.readbuf.take_data() {
            return self.readbuf.put_data(buf, v, s);
        }
        match self.receiver.poll_recv(cx) {
            Poll::Ready(Some(msg)) => self.readbuf.put_data(buf, msg, 0),
            // Sender dropped: signal a clean end of stream. An error here would
            // make `tokio::io::copy` and similar helpers fail instead of
            // finishing normally.
            Poll::Ready(None) => Poll::Ready(Ok(())),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Combined reader/writer for a forwarded port connection. Reads go to the
/// inner [`ForwardedPortReader`] and writes to the inner [`ForwardedPortWriter`].
pub struct ForwardedPortRW(ForwardedPortReader, ForwardedPortWriter);
impl AsyncRead for ForwardedPortRW {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
impl AsyncWrite for ForwardedPortRW {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.1).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.1).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
Pin::new(&mut self.1).poll_shutdown(cx)
}
}
/// Factory for the per-client SSH servers. Cloning is cheap because the config
/// is shared through an `Arc`.
#[derive(Clone)]
struct Server {
    config: Arc<russh::server::Config>,
}
impl Server {
    /// Serves one client session over `rw` on a background task.
    ///
    /// After the SSH handshake and authentication, the task keeps the client's
    /// port forwards in sync with `ports`: it requests `tcpip-forward` for added
    /// ports and cancels it for removed ones. Each incoming forwarded connection
    /// is handed to the receiver registered for its port. A connection is
    /// closed if its port is not forwarded or that port's receiver has gone
    /// away.
    ///
    /// The returned handle resolves when the session ends. It carries the
    /// error if the handshake or the session itself failed.
    pub fn run_stream(
        &mut self,
        rw: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
        mut ports: watch::Receiver<PortMap>,
    ) -> JoinHandle<Result<(), russh::Error>> {
        let mut server_session = self.new_client(None);
        let mut server_connection_rx = server_session.take_rx().expect("expected to have tx");
        let authed_tx = server_session.take_authed().expect("expected to have tx");
        let config = self.config.clone();
        tokio::spawn(async move {
            log::debug!("starting to serve host relay client session");
            let session = match russh::server::run_stream(config, rw, server_session).await {
                Ok(s) => s,
                Err(e) => {
                    log::error!("error handshaking session: {}", e);
                    return Err(e);
                }
            };
            if authed_tx.await.is_err() {
                log::debug!("connection closed before auth");
                return Ok(());
            }
            log::debug!("host relay client session successfully authed");
            // Ports this session currently forwards; diffed against each new map.
            let mut known_ports: PortMap = HashMap::new();
            tokio::pin!(session);
            loop {
                tokio::select! {
                    r = &mut session => return r,
                    cnx = server_connection_rx.recv() => match cnx {
                        Some(cnx) => {
                            let port = cnx.port;
                            match known_ports.get(&port) {
                                Some(port_tx) => {
                                    // The port's consumer is gone. Close the channel
                                    // instead of leaving the client waiting on it.
                                    if let Err(mpsc::error::SendError(cnx)) = port_tx.send(cnx) {
                                        cnx.close().await;
                                    }
                                }
                                None => {
                                    log::debug!("closing connection to unforwarded port {}", port);
                                    cnx.close().await;
                                }
                            }
                        },
                        None => {
                            log::debug!("no more connections on host relay client session, ending");
                            return Ok(());
                        },
                    },
                    // Match only `Ok`. Once the sender is dropped, `changed()`
                    // fails immediately on every call; matching that error would
                    // make this loop spin. A failed pattern just disables the
                    // branch for this iteration.
                    Ok(()) = ports.changed() => {
                        let new_ports = ports.borrow_and_update().clone();
                        for port in new_ports.keys() {
                            if !known_ports.contains_key(port) {
                                session.handle().forward_tcpip("127.0.0.1".to_string(), *port).await.ok();
                            }
                        }
                        for port in known_ports.keys() {
                            if !new_ports.contains_key(port) {
                                session.handle().cancel_forward_tcpip("127.0.0.1".to_string(), *port).await.ok();
                            }
                        }
                        known_ports = new_ports;
                    },
                }
            }
        })
    }
}
async fn forward_port_to_tcp(port: u16, mut rx: mpsc::UnboundedReceiver<ForwardedPortConnection>) {
let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port);
let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), port);
while let Some(mut conn) = rx.recv().await {
let mut futs = FuturesUnordered::new();
futs.push(TcpStream::connect(&ipv4_addr));
futs.push(TcpStream::connect(&ipv6_addr));
let mut last_result = None;
while let Some(r) = futs.next().await {
let ok = r.is_ok();
last_result = Some(r);
if ok {
break;
}
}
let mut stream = match last_result.unwrap() {
Ok(s) => s,
Err(e) => {
log::info!("Error connecting forwarding to port {}, {}", port, e);
conn.close().await;
continue;
}
};
log::debug!("Forwarded connection to port {}", port);
tokio::spawn(async move {
let mut read_buf = vec![0u8; 1024 * 64].into_boxed_slice();
loop {
tokio::select! {
n = stream.read(&mut read_buf) => match n {
Ok(0) => {
log::debug!("EOF from TCP stream, ending");
break;
},
Ok(n) => {
if (conn.send(&read_buf[..n]).await).is_err() {
log::debug!("channel was closed, ending forwarded port");
break;
}
},
Err(e) => {
log::debug!("error from TCP stream, ending: {}", e);
break;
}
},
m = conn.recv() => match m {
Some(data) => {
if let Err(e) = stream.write_all(&data).await {
log::debug!("error writing data to channel, ending: {}", e);
break;
}
},
None => {
log::debug!("EOF from channel, ending");
break;
}
}
}
}
});
}
}
impl ServerTrait for Server {
    type Handler = ServerHandle;

    /// Creates a fresh handler for a new session. The peer address is ignored;
    /// `Server::run_stream` passes `None`.
    fn new_client(&mut self, _peer_addr: Option<std::net::SocketAddr>) -> ServerHandle {
        ServerHandle::new()
    }
}
/// Per-session SSH server handler.
///
/// Signals when authentication succeeds, and turns forwarded-tcpip channels
/// opened by the client into [`ForwardedPortConnection`]s.
struct ServerHandle {
    /// Fired once, when authentication succeeds.
    authed_tx: Option<oneshot::Sender<()>>,
    /// Taken by the session runner to wait for authentication.
    authed_rx: Option<oneshot::Receiver<()>>,
    /// Delivers new forwarded connections to the session runner.
    cnx_tx: mpsc::UnboundedSender<ForwardedPortConnection>,
    /// Taken by the session runner; receives what `cnx_tx` sends.
    cnx_rx: Option<mpsc::UnboundedReceiver<ForwardedPortConnection>>,
    /// Routes each channel's data to its connection's receiver.
    channel_senders: HashMap<russh::ChannelId, mpsc::Sender<Vec<u8>>>,
}
impl ServerHandle {
pub fn new() -> Self {
let (authed_tx, authed_rx) = oneshot::channel();
let (cnx_tx, cnx_rx) = mpsc::unbounded_channel();
Self {
authed_rx: Some(authed_rx),
authed_tx: Some(authed_tx),
cnx_rx: Some(cnx_rx),
cnx_tx,
channel_senders: HashMap::new(),
}
}
pub fn take_rx(&mut self) -> Option<mpsc::UnboundedReceiver<ForwardedPortConnection>> {
self.cnx_rx.take()
}
pub fn take_authed(&mut self) -> Option<oneshot::Receiver<()>> {
self.authed_rx.take()
}
}
#[async_trait]
impl russh::server::Handler for ServerHandle {
    type Error = russh::Error;

    /// Signals `authed_tx` the first time authentication succeeds.
    async fn auth_succeeded(
        mut self,
        session: russh::server::Session,
    ) -> Result<(Self, russh::server::Session), Self::Error> {
        if let Some(tx) = self.authed_tx.take() {
            tx.send(()).ok();
        }
        Ok((self, session))
    }

    /// Accepts "none" authentication for every user.
    async fn auth_none(self, _: &str) -> Result<(Self, russh::server::Auth), Self::Error> {
        Ok((self, russh::server::Auth::Accept))
    }

    /// Wraps a forwarded-tcpip channel opened by the client in a
    /// [`ForwardedPortConnection`] and hands it to the session runner.
    ///
    /// The channel is rejected if the session runner no longer receives
    /// connections, because nothing would service it.
    async fn channel_open_forwarded_tcpip(
        mut self,
        channel: russh::Channel<russh::server::Msg>,
        _host_to_connect: &str,
        port_to_connect: u32,
        _originator_address: &str,
        _originator_port: u32,
        session: russh::server::Session,
    ) -> Result<(Self, bool, russh::server::Session), Self::Error> {
        let id = channel.id();
        // Bounded, so `data` below waits when the consumer falls behind.
        let (sender, receiver) = mpsc::channel(10);
        let delivered = self
            .cnx_tx
            .send(ForwardedPortConnection {
                port: port_to_connect,
                channel: id,
                handle: session.handle(),
                receiver,
            })
            .is_ok();
        if delivered {
            self.channel_senders.insert(id, sender);
        }
        Ok((self, delivered, session))
    }

    /// Routes data received on `channel` to its connection.
    ///
    /// This handler does not return until the connection's buffer (capacity 10)
    /// has room. If the connection's receiver is gone, the route is dropped.
    async fn data(
        mut self,
        channel: russh::ChannelId,
        data: &[u8],
        session: russh::server::Session,
    ) -> Result<(Self, russh::server::Session), Self::Error> {
        if let Some(sender) = self.channel_senders.get(&channel) {
            if sender.send(data.to_vec()).await.is_err() {
                self.channel_senders.remove(&channel);
            }
        }
        Ok((self, session))
    }

    /// Stops routing data for a channel the client has closed.
    ///
    /// Dropping the sender ends the matching connection's receive stream, so
    /// its consumer sees end-of-stream instead of waiting forever. It also
    /// stops `channel_senders` from growing with dead entries.
    async fn channel_close(
        mut self,
        channel: russh::ChannelId,
        session: russh::server::Session,
    ) -> Result<(Self, russh::server::Session), Self::Error> {
        self.channel_senders.remove(&channel);
        Ok((self, session))
    }
}
/// Channel events reported by the relay's primary SSH client session to the
/// task that serves client streams.
#[derive(Debug)]
enum ChannelOp {
    /// The relay opened a `client-ssh-session-stream` channel.
    Open(russh::ChannelId),
    /// The channel was closed.
    Close(russh::ChannelId),
    /// Data arrived on the channel.
    Data(russh::ChannelId, Vec<u8>),
}
/// SSH client handler for the primary relay session. Reports channel events
/// through `sender`.
struct Client {
    sender: mpsc::UnboundedSender<ChannelOp>,
}
impl Client {
    /// Creates a client handler plus the receiver for the channel events it
    /// reports.
    pub fn new() -> (Self, mpsc::UnboundedReceiver<ChannelOp>) {
        let (sender, events) = mpsc::unbounded_channel();
        let client = Self { sender };
        (client, events)
    }
}
#[async_trait]
impl russh::client::Handler for Client {
    type Error = russh::Error;

    /// Accepts any server key without checking it.
    ///
    /// NOTE(review): the relay session negotiates the `NONE` key algorithm, so
    /// trust presumably rests on the authenticated websocket. Confirm.
    async fn check_server_key(
        self,
        _server_public_key: &russh_keys::key::PublicKey,
    ) -> Result<(Self, bool), Self::Error> {
        Ok((self, true))
    }

    /// Accepts only `client-ssh-session-stream` channels, reporting each one as
    /// [`ChannelOp::Open`]. Every other unknown channel type is refused.
    fn server_channel_handle_unknown(
        &self,
        channel: russh::ChannelId,
        channel_type: &[u8],
    ) -> bool {
        let is_client_stream = channel_type == b"client-ssh-session-stream";
        if is_client_stream {
            let _ = self.sender.send(ChannelOp::Open(channel));
        }
        is_client_stream
    }

    /// Reports the channel as closed.
    async fn channel_close(
        self,
        channel: russh::ChannelId,
        session: russh::client::Session,
    ) -> Result<(Self, russh::client::Session), Self::Error> {
        let _ = self.sender.send(ChannelOp::Close(channel));
        Ok((self, session))
    }

    /// Reports a copy of the received data.
    async fn data(
        self,
        channel: russh::ChannelId,
        data: &[u8],
        session: russh::client::Session,
    ) -> Result<(Self, russh::client::Session), Self::Error> {
        let op = ChannelOp::Data(channel, data.to_vec());
        let _ = self.sender.send(op);
        Ok((self, session))
    }
}
/// Byte stream over one channel of the relay client session. Used as the
/// transport for a per-client SSH server.
struct AsyncRWChannel {
    id: russh::ChannelId,
    session: Arc<russh::client::Handle<Client>>,
    /// Data received on the channel, fed through the sender returned by `new`.
    incoming: mpsc::UnboundedReceiver<Vec<u8>>,
    /// Presumably holds received data not yet handed to the reader.
    /// `ReadBuffer` lives in `super::io`; confirm.
    readbuf: super::io::ReadBuffer,
    /// Whether `write_fut` holds a send that has not completed yet.
    is_write_fut_valid: bool,
    /// Reusable allocation for the in-flight send future.
    write_fut: tokio_util::sync::ReusableBoxFuture<'static, Result<(), russh::CryptoVec>>,
}
impl AsyncRWChannel {
    /// Wraps channel `id` of the relay client `session` as an async byte stream.
    ///
    /// Returns the stream and the sender used to feed it the data received on
    /// the channel.
    pub fn new(
        id: russh::ChannelId,
        session: Arc<russh::client::Handle<Client>>,
    ) -> (Self, mpsc::UnboundedSender<Vec<u8>>) {
        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
        // Placeholder future; it is replaced before it is ever polled.
        let idle_write = tokio_util::sync::ReusableBoxFuture::new(make_client_write_fut(None));
        let channel = Self {
            id,
            session,
            incoming: incoming_rx,
            readbuf: super::io::ReadBuffer::default(),
            is_write_fut_valid: false,
            write_fut: idle_write,
        };
        (channel, incoming_tx)
    }
}
/// Builds the future that writes one buffer to a channel of the relay client
/// session.
///
/// `None` only seeds a `ReusableBoxFuture` with a placeholder that is never
/// polled; polling that placeholder panics.
async fn make_client_write_fut(
    data: Option<(
        Arc<russh::client::Handle<Client>>,
        russh::ChannelId,
        Vec<u8>,
    )>,
) -> Result<(), russh::CryptoVec> {
    if let Some((session, channel, bytes)) = data {
        session.data(channel, CryptoVec::from(bytes)).await
    } else {
        unreachable!("this future should not be pollable in this state")
    }
}
impl AsyncWrite for AsyncRWChannel {
    /// Starts sending `buf` as channel data (unless a send is already in
    /// flight) and drives the send. The whole buffer is reported as written
    /// once the send completes.
    ///
    /// While a send is in flight, later calls keep driving that send and ignore
    /// their own `buf`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        if !self.is_write_fut_valid {
            let session = self.session.clone();
            let id = self.id;
            self.write_fut
                .set(make_client_write_fut(Some((session, id, buf.to_vec()))));
            self.is_write_fut_valid = true;
        }
        self.poll_flush(cx).map(|r| r.map(|_| buf.len()))
    }

    /// Drives any in-flight send to completion. A failed send is reported as
    /// an `io::ErrorKind::Other` "EOF" error.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        if !self.is_write_fut_valid {
            return Poll::Ready(Ok(()));
        }
        match self.write_fut.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(_)) => {
                self.is_write_fut_valid = false;
                Poll::Ready(Ok(()))
            }
            Poll::Ready(Err(_)) => {
                self.is_write_fut_valid = false;
                Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "EOF")))
            }
        }
    }

    /// Completes any in-flight send. Nothing further is sent on the channel.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        // `AsyncWrite` requires shutdown to imply a flush. Don't report Ready
        // while a send is still pending.
        self.poll_flush(cx)
    }
}
impl AsyncRead for AsyncRWChannel {
    /// Reads data received on the channel. Data buffered in `readbuf` from an
    /// earlier message is returned first.
    ///
    /// When the channel's sender is dropped, EOF is reported the way
    /// `AsyncRead` defines it: `Ok(())` with nothing written to `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if let Some((v, s)) = self.readbuf.take_data() {
            return self.readbuf.put_data(buf, v, s);
        }
        match self.incoming.poll_recv(cx) {
            Poll::Ready(Some(msg)) => self.readbuf.put_data(buf, msg, 0),
            // Channel closed: a zero-byte read is the AsyncRead EOF signal, so
            // the SSH server sees a normal end of stream, not an opaque error.
            Poll::Ready(None) => Poll::Ready(Ok(())),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Handle to a running relay host connection, returned by
/// [`RelayTunnelHost::connect`]. Awaiting it waits for the background task to
/// end.
pub struct RelayHandle {
    /// Relay endpoint registered for this connection.
    endpoint: TunnelRelayTunnelEndpoint,
    /// Primary SSH session with the relay.
    session: Arc<russh::client::Handle<Client>>,
    /// Background task that serves client streams.
    join: JoinHandle<Result<(), russh::Error>>,
}
impl RelayHandle {
pub fn endpoint(&self) -> &TunnelRelayTunnelEndpoint {
&self.endpoint
}
pub async fn close(self) -> Result<(), TunnelError> {
let result = self
.session
.disconnect(russh::Disconnect::ByApplication, "disconnect", "en")
.await;
self.join.await.ok();
result.map_err(TunnelError::TunnelRelayDisconnected)
}
}
impl std::future::Future for RelayHandle {
    type Output = Result<(), TunnelError>;

    /// Resolves when the background relay task ends.
    ///
    /// An SSH error from the task becomes
    /// [`TunnelError::TunnelRelayDisconnected`]. A task that panicked or was
    /// cancelled is treated as a clean shutdown.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let joined = match std::future::Future::poll(Pin::new(&mut self.join), cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(joined) => joined,
        };
        let outcome = match joined {
            Ok(Err(e)) => Err(TunnelError::TunnelRelayDisconnected(e)),
            Ok(Ok(())) | Err(_) => Ok(()),
        };
        Poll::Ready(outcome)
    }
}