use std::collections::HashMap;
use std::pin::pin;
use anyhow::{Context, Result};
use sshx_core::proto::{
client_update::ClientMessage, server_update::ServerMessage,
sshx_service_client::SshxServiceClient, ClientUpdate, CloseRequest, NewShell, OpenRequest,
};
use sshx_core::{rand_alphanumeric, Sid};
use tokio::sync::mpsc;
use tokio::task;
use tokio::time::{self, Duration, Instant, MissedTickBehavior};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use tonic::transport::Channel;
use tracing::{debug, error, warn};
use crate::encrypt::Encrypt;
use crate::runner::{Runner, ShellData};
/// Period of the heartbeat timer: while a channel is open, an empty (default)
/// `ClientUpdate` is sent to the server each time this interval ticks.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum lifetime of a single channel: once this much time has elapsed the
/// channel loop returns successfully so that the caller opens a fresh connection.
const RECONNECT_INTERVAL: Duration = Duration::from_secs(60);
/// Handles a single session's communication with the remote server.
///
/// Holds the session identity returned by the server, the encryption state,
/// and the channels used to route messages to and from per-shell tasks.
pub struct Controller {
/// Server origin that every gRPC client connection is opened against.
origin: String,
/// Shell runner; a clone is handed to each spawned shell task.
runner: Runner,
/// Encryption state built from `encryption_key`; used on incoming input data
/// and cloned into each shell task.
encrypt: Encrypt,
/// Randomly generated key for this session; appended to `url` after a `#`.
encryption_key: String,
/// Session name as returned by the server's open response.
name: String,
/// Session token as returned by the server's open response; sent together
/// with `name` when opening a channel or closing the session.
token: String,
/// Session URL from the server with `#<encryption_key>` appended.
url: String,
/// Sending halves of the bounded channels that route data to each shell task,
/// keyed by shell ID.
shells_tx: HashMap<Sid, mpsc::Sender<ShellData>>,
/// Sender cloned into shell tasks so they can emit client messages.
output_tx: mpsc::Sender<ClientMessage>,
/// Receiving end of `output_tx`, drained by the channel loop and forwarded
/// to the server.
output_rx: mpsc::Receiver<ClientMessage>,
}
impl Controller {
/// Construct a new controller by opening a session on the remote server.
///
/// A fresh 14-character alphanumeric key is generated and `Encrypt::new` is
/// run on a blocking thread while the gRPC connection is being established.
/// The open request carries `origin`, `name` and `encrypt.zeros()`, not the
/// key string itself; the key is only appended locally to the returned URL
/// after a `#`.
///
/// # Errors
///
/// Fails if the connection cannot be established, if the blocking task
/// cannot be joined, or if the server rejects the open request.
pub async fn new(origin: &str, name: &str, runner: Runner) -> Result<Self> {
    debug!(%origin, "connecting to server");

    // Start building the encryption state in the background so that it
    // overlaps with the network round trip below.
    let encryption_key = rand_alphanumeric(14);
    let task_key = encryption_key.clone();
    let kdf_task = task::spawn_blocking(move || Encrypt::new(&task_key));

    let mut client = Self::connect(origin).await?;
    let encrypt = kdf_task.await?;

    let open_request = OpenRequest {
        origin: origin.into(),
        encrypted_zeros: encrypt.zeros().into(),
        name: name.into(),
    };
    let opened = client.open(open_request).await?.into_inner();
    let url = format!("{}#{}", opened.url, encryption_key);

    let (output_tx, output_rx) = mpsc::channel(64);
    let controller = Self {
        origin: origin.into(),
        runner,
        encrypt,
        encryption_key,
        name: opened.name,
        token: opened.token,
        url,
        shells_tx: HashMap::new(),
        output_tx,
        output_rx,
    };
    Ok(controller)
}
/// Create a new gRPC client connected to the given origin.
///
/// Every caller in this file builds a fresh client through this helper
/// rather than reusing an existing one.
async fn connect(origin: &str) -> Result<SshxServiceClient<Channel>, tonic::transport::Error> {
    let endpoint = origin.to_owned();
    SshxServiceClient::connect(endpoint).await
}
/// Returns the name of the session, as assigned in the server's open response.
pub fn name(&self) -> &str {
    self.name.as_str()
}
/// Returns the URL of the session, including the `#<encryption_key>` fragment.
pub fn url(&self) -> &str {
    self.url.as_str()
}
/// Returns the randomly generated encryption key for this session.
pub fn encryption_key(&self) -> &str {
    self.encryption_key.as_str()
}
/// Run the controller forever, repeatedly opening a channel to the server.
///
/// Each time the channel ends with an error, the controller sleeps for
/// `2^min(retries, 4)` seconds (1, 2, 4, 8, then 16) before trying again.
/// The retry counter is reset whenever the failed attempt had lasted at
/// least 10 seconds. A channel that ends without error is reopened
/// immediately.
pub async fn run(&mut self) -> ! {
    let mut retries: u32 = 0;
    let mut last_retry = Instant::now();
    loop {
        match self.try_channel().await {
            Ok(()) => {}
            Err(err) => {
                // A channel that stayed up for a while counts as healthy, so
                // the backoff starts over from the shortest delay.
                if last_retry.elapsed() >= Duration::from_secs(10) {
                    retries = 0;
                }
                let secs = 2_u64.pow(retries.min(4));
                error!(%err, "disconnected, retrying in {secs}s...");
                time::sleep(Duration::from_secs(secs)).await;
                retries += 1;
            }
        }
        last_retry = Instant::now();
    }
}
/// Open one streaming channel to the server and service it until it ends.
///
/// Used by `run()`, which handles the retry policy. The loop multiplexes four
/// event sources: the heartbeat timer, messages produced by shell tasks,
/// messages arriving from the server, and the reconnect timer.
///
/// # Errors
///
/// Returns an error if connecting or opening the channel fails, if a message
/// cannot be queued for the server, if `output_rx` is closed, if the server
/// stream ends or yields an error, or if a server update carries no message.
/// Returns `Ok(())` when the reconnect timer fires.
async fn try_channel(&mut self) -> Result<()> {
// `tx` feeds the outgoing request stream; `rx` is handed to the gRPC call below.
let (tx, rx) = mpsc::channel(16);
// The hello message ("<name>,<token>") is queued before the call is made, so
// it is the first update on the stream.
let hello = ClientMessage::Hello(format!("{},{}", self.name, self.token));
send_msg(&tx, hello).await?;
let mut client = Self::connect(&self.origin).await?;
let resp = client.channel(ReceiverStream::new(rx)).await?;
// Stream of updates sent by the server.
let mut messages = resp.into_inner();
let mut interval = time::interval(HEARTBEAT_INTERVAL);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
// One-shot timer, pinned so it can be polled by reference across iterations.
let mut reconnect = pin!(time::sleep(RECONNECT_INTERVAL));
loop {
// Only the server-message arm yields a value; the other arms either
// `continue` the loop or return from the function.
let message = tokio::select! {
_ = interval.tick() => {
// Heartbeat: an empty update with no client message.
tx.send(ClientUpdate::default()).await?;
continue;
}
msg = self.output_rx.recv() => {
// Forward a message produced by a shell task to the server.
let msg = msg.context("unreachable: output_tx was closed?")?;
send_msg(&tx, msg).await?;
continue;
}
item = messages.next() => {
// First `?` handles end of stream (`None`), second `?` handles an
// error item from the stream.
item.context("server closed connection")??
.server_message
.context("server message is missing")?
}
_ = &mut reconnect => {
// Channel lifetime elapsed: end successfully so the caller reconnects.
return Ok(()); }
};
match message {
ServerMessage::Input(input) => {
// NOTE(review): presumably this decrypts the input; the meaning of the
// constant 0x200000000 is defined by `Encrypt::segment` — confirm there.
let data = self.encrypt.segment(0x200000000, input.offset, &input.data);
if let Some(sender) = self.shells_tx.get(&Sid(input.id)) {
// A failed send (receiver dropped) is deliberately ignored.
sender.send(ShellData::Data(data)).await.ok();
} else {
warn!(%input.id, "received data for non-existing shell");
}
}
ServerMessage::CreateShell(new_shell) => {
let id = Sid(new_shell.id);
let center = (new_shell.x, new_shell.y);
if !self.shells_tx.contains_key(&id) {
self.spawn_shell_task(id, center);
} else {
warn!(%id, "server asked to create duplicate shell");
}
}
ServerMessage::CloseShell(id) => {
// Removing the entry drops the sender for this shell's channel.
// Presumably the shell task terminates once its channel closes —
// confirm in `Runner`.
self.shells_tx.remove(&Sid(id));
send_msg(&tx, ClientMessage::ClosedShell(id)).await?;
}
ServerMessage::Sync(seqnums) => {
for (id, seq) in seqnums.map {
if let Some(sender) = self.shells_tx.get(&Sid(id)) {
sender.send(ShellData::Sync(seq)).await.ok();
} else {
// Tell the server that a shell it still tracks is not running here.
warn!(%id, "received sequence number for non-existing shell");
send_msg(&tx, ClientMessage::ClosedShell(id)).await?;
}
}
}
ServerMessage::Resize(msg) => {
if let Some(sender) = self.shells_tx.get(&Sid(msg.id)) {
sender.send(ShellData::Size(msg.rows, msg.cols)).await.ok();
} else {
warn!(%msg.id, "received resize for non-existing shell");
}
}
ServerMessage::Ping(ts) => {
// Echo the payload back unchanged.
send_msg(&tx, ClientMessage::Pong(ts)).await?;
}
ServerMessage::Error(err) => {
// Server-reported errors are logged but do not end the channel.
error!(?err, "error received from server");
}
}
}
}
fn spawn_shell_task(&mut self, id: Sid, center: (i32, i32)) {
let (shell_tx, shell_rx) = mpsc::channel(16);
let opt = self.shells_tx.insert(id, shell_tx);
debug_assert!(opt.is_none(), "shell ID cannot be in existing tasks");
let runner = self.runner.clone();
let encrypt = self.encrypt.clone();
let output_tx = self.output_tx.clone();
tokio::spawn(async move {
debug!(%id, "spawning new shell");
let new_shell = NewShell {
id: id.0,
x: center.0,
y: center.1,
};
if let Err(err) = output_tx.send(ClientMessage::CreatedShell(new_shell)).await {
error!(%id, ?err, "failed to send shell creation message");
return;
}
if let Err(err) = runner.run(id, encrypt, shell_rx, output_tx.clone()).await {
let err = ClientMessage::Error(err.to_string());
output_tx.send(err).await.ok();
}
output_tx.send(ClientMessage::ClosedShell(id.0)).await.ok();
});
}
/// Terminate this session by sending a close request to the server.
///
/// A fresh connection is opened for the request, which identifies the
/// session by its name and token.
///
/// # Errors
///
/// Fails if the connection cannot be established or the server rejects
/// the close request.
pub async fn close(&self) -> Result<()> {
    debug!("closing session");
    let name = self.name.clone();
    let token = self.token.clone();
    let request = CloseRequest { name, token };
    let mut rpc = Self::connect(&self.origin).await?;
    rpc.close(request).await?;
    Ok(())
}
}
/// Wrap a client message in a `ClientUpdate` and queue it on the outgoing stream.
///
/// # Errors
///
/// Fails with "failed to send message to server" if the receiving half of
/// `tx` has been closed.
async fn send_msg(tx: &mpsc::Sender<ClientUpdate>, message: ClientMessage) -> Result<()> {
    let client_message = Some(message);
    let sent = tx.send(ClientUpdate { client_message }).await;
    sent.context("failed to send message to server")
}