use std::collections::{BTreeMap, HashMap};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{ensure, Result};
use futures_util::{SinkExt, StreamExt};
use hyper::{server::conn::AddrIncoming, StatusCode};
use sshx::encrypt::Encrypt;
use sshx_core::proto::sshx_service_client::SshxServiceClient;
use sshx_core::{Sid, Uid};
use sshx_server::{
state::ServerState,
web::protocol::{WsClient, WsServer, WsUser, WsWinsize},
Server,
};
use tokio::net::{TcpListener, TcpStream};
use tokio::time;
use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
use tonic::transport::Channel;
/// An in-process server instance for integration tests.
///
/// Created with [`TestServer::new`]; dropping the value calls `shutdown()` on
/// the wrapped server.
pub struct TestServer {
    /// Address reported by the bound listener (the port is chosen by the OS).
    local_addr: SocketAddr,
    /// Shared handle to the server; a clone is held by the spawned listen task.
    server: Arc<Server>,
}
impl TestServer {
    /// Starts a server listening on an OS-assigned port of the IPv6 loopback
    /// address (`[::1]:0`).
    ///
    /// The server is built from default options and driven by a spawned Tokio
    /// task, so this must be called from within a Tokio runtime.
    ///
    /// # Panics
    ///
    /// Panics if the listener cannot be bound or converted, or if the server
    /// cannot be constructed. The spawned task panics if `listen` returns an
    /// error.
    pub async fn new() -> Self {
        let listener = TcpListener::bind("[::1]:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let incoming = AddrIncoming::from_listener(listener).unwrap();

        let server = Arc::new(Server::new(Default::default()).unwrap());

        // The background task owns its own handle so that `self.server` stays
        // usable for `state()` and for shutdown on drop.
        let background = Arc::clone(&server);
        tokio::spawn(async move {
            background.listen(incoming).await.unwrap();
        });

        Self { local_addr, server }
    }

    /// Returns the socket address the server is listening on.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Returns the `http://` base URL of the server.
    pub fn endpoint(&self) -> String {
        let addr = self.local_addr();
        format!("http://{}", addr)
    }

    /// Returns the `ws://` URL with path `/api/s/<name>` on this server.
    ///
    /// NOTE(review): `name` is presumably a session name; confirm against the
    /// server's routing.
    pub fn ws_endpoint(&self, name: &str) -> String {
        let addr = self.local_addr();
        format!("ws://{}/api/s/{}", addr, name)
    }

    /// Connects a gRPC client to [`TestServer::endpoint`].
    ///
    /// # Panics
    ///
    /// Panics if the connection cannot be established.
    pub async fn grpc_client(&self) -> SshxServiceClient<Channel> {
        let url = self.endpoint();
        SshxServiceClient::connect(url).await.unwrap()
    }

    /// Returns the server state handle obtained from the wrapped server.
    pub fn state(&self) -> Arc<ServerState> {
        self.server.state()
    }
}
impl Drop for TestServer {
    /// Signals the wrapped server to shut down when the test server goes out
    /// of scope.
    ///
    /// NOTE(review): this presumably also ends the background listen task
    /// spawned in `new`; confirm against `Server::shutdown`.
    fn drop(&mut self) {
        let Self { server, .. } = self;
        server.shutdown();
    }
}
/// A WebSocket test client that records what the server sends to it.
///
/// The public fields are only updated while [`ClientSocket::flush`] runs.
pub struct ClientSocket {
    /// Underlying WebSocket connection.
    inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
    /// Cipher state built from the key passed to `connect`.
    encrypt: Encrypt,

    /// User ID from the most recent `Hello` message; `Uid(0)` until then.
    pub user_id: Uid,
    /// Current users, replaced by `Users` and patched by `UserDiff` messages.
    pub users: BTreeMap<Uid, WsUser>,
    /// Current shells, replaced by each `Shells` message.
    pub shells: BTreeMap<Sid, WsWinsize>,
    /// Decrypted text accumulated per shell from `Chunks` messages.
    pub data: HashMap<Sid, String>,
    /// Every `Hear` message received, in arrival order, as the
    /// `(Uid, String, String)` triple carried by the message.
    pub messages: Vec<(Uid, String, String)>,
    /// Every `Error` message received, in arrival order.
    pub errors: Vec<String>,
}
impl ClientSocket {
    /// Opens a WebSocket connection to `uri` and sends the authentication
    /// message derived from `key`.
    ///
    /// The server's reply is not read here; it is only processed by a later
    /// call to [`ClientSocket::flush`].
    ///
    /// # Errors
    ///
    /// Returns an error if the WebSocket handshake fails or if the response
    /// status is not `101 Switching Protocols`.
    ///
    /// # Panics
    ///
    /// Panics if the authentication message cannot be sent.
    pub async fn connect(uri: &str, key: &str) -> Result<Self> {
        // `resp` keeps its name: `ensure!` embeds the condition text in its error.
        let (inner, resp) = tokio_tungstenite::connect_async(uri).await?;
        ensure!(resp.status() == StatusCode::SWITCHING_PROTOCOLS);

        let mut socket = Self {
            inner,
            encrypt: Encrypt::new(key),
            user_id: Uid(0),
            users: BTreeMap::new(),
            shells: BTreeMap::new(),
            data: HashMap::new(),
            messages: Vec::new(),
            errors: Vec::new(),
        };
        socket.authenticate().await;
        Ok(socket)
    }

    /// Sends an `Authenticate` message carrying the value of `encrypt.zeros()`.
    async fn authenticate(&mut self) {
        let handshake = WsClient::Authenticate(self.encrypt.zeros().into());
        self.send(handshake).await;
    }

    /// Serializes `msg` with `ciborium` and sends it as one binary frame.
    ///
    /// # Panics
    ///
    /// Panics if serialization or the WebSocket send fails.
    pub async fn send(&mut self, msg: WsClient) {
        let mut encoded = Vec::new();
        ciborium::ser::into_writer(&msg, &mut encoded).unwrap();

        let frame = Message::Binary(encoded);
        self.inner.send(frame).await.unwrap();
    }

    /// Encrypts `data` and sends it as a `Data` message for shell `id`.
    ///
    /// Every call uses the same fixed offset.
    /// NOTE(review): the offset looks like an arbitrary test value; confirm
    /// that reusing it is acceptable only in tests.
    pub async fn send_input(&mut self, id: Sid, data: &[u8]) {
        // Stream number passed to `segment` for input; `flush` uses
        // `0x100000000 | id` when decrypting chunks.
        const INPUT_STREAM: u64 = 0x200000000;
        const OFFSET: u64 = 42;

        let ciphertext = self.encrypt.segment(INPUT_STREAM, OFFSET, data);
        self.send(WsClient::Data(id, ciphertext.into(), OFFSET)).await;
    }

    /// Waits for the next binary frame and decodes it as a [`WsServer`].
    ///
    /// Frames that are neither text nor binary are skipped. Returns `None`
    /// once the stream has ended.
    ///
    /// # Panics
    ///
    /// Panics on a transport error, on a text frame, or if decoding fails.
    async fn recv(&mut self) -> Option<WsServer> {
        while let Some(incoming) = self.inner.next().await {
            match incoming.unwrap() {
                Message::Text(_) => panic!("unexpected text message over WebSocket"),
                Message::Binary(payload) => {
                    return Some(ciborium::de::from_reader(&*payload).unwrap());
                }
                // Any other frame type carries nothing to decode; keep waiting.
                _ => {}
            }
        }
        None
    }

    /// Asserts that the very next frame is a close frame with status `code`.
    ///
    /// # Panics
    ///
    /// Panics if the stream has ended, on a transport error, if the next
    /// frame is not a close frame with a payload, or if its code differs.
    pub async fn expect_close(&mut self, code: u16) {
        let msg = self.inner.next().await.unwrap().unwrap();
        match &msg {
            Message::Close(Some(frame)) => assert!(frame.code == code.into()),
            other => panic!("unexpected non-close message over WebSocket: {:?}", other),
        }
    }

    /// Receives and applies server messages for at most 50 milliseconds.
    ///
    /// Returns earlier if the stream ends. Hitting the time limit is not
    /// treated as an error.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `recv`, on an `InvalidAuth`
    /// message, if a chunk's sequence number does not match the data already
    /// stored, or if decrypted data is not valid UTF-8.
    pub async fn flush(&mut self) {
        const FLUSH_DURATION: Duration = Duration::from_millis(50);

        let drain = async {
            while let Some(update) = self.recv().await {
                self.apply_update(update);
            }
        };
        // An elapsed timeout just means no more messages arrived in time.
        let _ = time::timeout(FLUSH_DURATION, drain).await;
    }

    /// Folds a single server message into the recorded client state.
    fn apply_update(&mut self, update: WsServer) {
        match update {
            WsServer::Hello(user_id, _) => self.user_id = user_id,
            WsServer::InvalidAuth() => panic!("invalid authentication"),
            WsServer::Users(users) => self.users = users.into_iter().collect(),
            WsServer::UserDiff(id, maybe_user) => match maybe_user {
                // `insert` overwrites any existing entry for `id`.
                Some(user) => {
                    self.users.insert(id, user);
                }
                None => {
                    self.users.remove(&id);
                }
            },
            WsServer::Shells(shells) => self.shells = shells.into_iter().collect(),
            WsServer::Chunks(id, seqnum, chunks) => {
                let stream_num = 0x100000000 | id.0 as u64;
                let text = self.data.entry(id).or_default();
                // The chunks must continue exactly where the stored data ends.
                assert_eq!(seqnum, text.len() as u64);
                for chunk in chunks {
                    // The offset is the byte length accumulated so far.
                    let offset = text.len() as u64;
                    let plaintext = self.encrypt.segment(stream_num, offset, &chunk);
                    text.push_str(std::str::from_utf8(&plaintext).unwrap());
                }
            }
            WsServer::Hear(id, name, msg) => self.messages.push((id, name, msg)),
            // These messages carry nothing that is recorded.
            WsServer::ShellLatency(_) | WsServer::Pong(_) => {}
            WsServer::Error(err) => self.errors.push(err),
        }
    }

    /// Returns the text accumulated for shell `id`, or `""` if none has been
    /// received.
    pub fn read(&self, id: Sid) -> &str {
        match self.data.get(&id) {
            Some(text) => text.as_str(),
            None => "",
        }
    }
}