use futures::channel::mpsc::{unbounded, UnboundedSender};
use futures::{SinkExt, StreamExt};
use tokio_tungstenite::tungstenite::Message;
use log::{debug, error, info, warn};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
pub mod config;
pub mod error;
mod local;
use crate::config::*;
use crate::error::*;
use actnel_lib::*;
pub use actnel_lib::DeviceId;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex};
/// Shared map from a tunnel stream id to the channel that feeds messages to that stream's
/// local side.
pub type ActiveStreams = Arc<RwLock<HashMap<StreamId, UnboundedSender<StreamMessage>>>>;

lazy_static::lazy_static! {
    /// Process-wide registry of the streams currently open through the tunnel.
    pub static ref ACTIVE_STREAMS: ActiveStreams = Arc::default();
    /// Reconnect token most recently received from the server in a `Ping`; offered in the
    /// next hello when no secret key is configured, and cleared when the wormhole is closed.
    pub static ref RECONNECT_TOKEN: Arc<Mutex<Option<ReconnectToken>>> =
        Arc::new(Mutex::new(None));
}
/// Message delivered to a single tunnel stream through its channel in `ACTIVE_STREAMS`.
#[derive(Debug, Clone)]
pub enum StreamMessage {
    /// Payload bytes received from the tunnel for this stream.
    Data(Vec<u8>),
    /// Signal that the stream has ended and its local side should be shut down.
    Close,
}
/// A tunnel client session: the configuration it was started with together with its live
/// connection to the wormhole control server.
pub struct Session {
    /// Settings used to reach the control server and to serve forwarded traffic.
    config: Config,
    /// Handshaken connection to the control server.
    wormhole: Wormhole,
}
impl Session {
pub async fn connect(config: Config) -> Result<Self, Error> {
let wormhole = Wormhole::connect(&config).await?;
Ok(Session {
config,
wormhole,
})
}
pub async fn listen(&self) -> Result<SocketAddr, Error> {
let config = self.config.clone();
let (restart_tx, _) = unbounded();
let _ = self.wormhole.listen(config, restart_tx).await;
Ok(self.config.local_addr.clone())
}
pub async fn close(&self) -> Result<(), Error> {
let _ = self.wormhole.close().await;
Ok(())
}
pub fn ingress_url(&self) -> String {
self.config.activation_url(self.wormhole.hostname.as_str())
}
pub fn quotas(&self) -> ClientQuotas {
self.wormhole.quotas.clone()
}
}
/// Client side of the control websocket. The socket itself is owned by a background task;
/// this handle talks to it only through in-process channels.
struct Wormhole {
    /// Queue of frames to be written to the websocket.
    sender: mpsc::UnboundedSender<Message>,
    /// Frames read from the websocket, behind a mutex so the handle can be shared by tasks.
    receiver: Arc<Mutex<mpsc::UnboundedReceiver<Message>>>,
    /// Sub-domain returned by the server in its successful hello.
    sub_domain: String,
    /// Hostname returned by the server in its successful hello.
    hostname: String,
    /// Quotas returned by the server in its successful hello.
    quotas: ClientQuotas,
}
impl Wormhole {
async fn connect(config: &Config) -> Result<Self, Error> {
let (mut websocket, _) = tokio_tungstenite::connect_async(&config.control_url).await?;
let client_hello = match config.secret_key.clone() {
Some(secret_key) => ClientHello::generate(
config.sub_domain.clone(),
ClientType::Auth { key: secret_key },
),
None => {
if let Some(reconnect) = RECONNECT_TOKEN.lock().await.clone() {
ClientHello::reconnect(reconnect)
} else {
ClientHello::generate(config.sub_domain.clone(), ClientType::Anonymous)
}
}
};
info!("connecting to wormhole...");
let hello = serde_json::to_vec(&client_hello).unwrap();
websocket
.send(Message::binary(hello))
.await
.expect("Failed to send client hello to wormhole server.");
let server_hello_data = websocket
.next()
.await
.ok_or(Error::NoResponseFromServer)??
.into_data();
let server_hello = serde_json::from_slice::<ServerHello>(&server_hello_data).map_err(|e| {
error!("Couldn't parse server_hello from {:?}", e);
Error::ServerReplyInvalid
})?;
let (sub_domain, hostname, quotas) = match server_hello {
ServerHello::Success {
sub_domain,
client_id,
hostname,
quotas,
} => {
info!("Server accepted our connection. I am client_{}", client_id);
(sub_domain, hostname, quotas)
}
ServerHello::AuthFailed => {
return Err(Error::AuthenticationFailed);
}
ServerHello::InvalidSubDomain => {
return Err(Error::InvalidSubDomain);
}
ServerHello::SubDomainInUse => {
return Err(Error::SubDomainInUse);
}
ServerHello::Error(error) => return Err(Error::ServerError(error)),
};
let (receive_tx, receive_rx) = mpsc::unbounded_channel();
let (send_tx, mut send_rx) = mpsc::unbounded_channel();
tokio::spawn({
async move {
let mut ws_stream = websocket;
loop {
tokio::select! {
message = ws_stream.next() => {
match message {
Some(Ok(msg)) => {
if receive_tx.send(msg).is_err() {
break; }
}
Some(Err(e)) => { warn!("websocket read error: {:?}", e);
break;
},
None => { warn!("websocket sent none");
break;
},
}
}
received = async {
send_rx.recv().await
} => {
if let Some(msg) = received {
if ws_stream.send(msg).await.is_err() {
break; }
} else {
break; }
}
}
}
}
});
Ok(Wormhole {
sender: send_tx,
receiver: Arc::new(Mutex::new(receive_rx)),
sub_domain,
hostname,
quotas,
})
}
pub async fn send_message(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
self.sender.send(message)
}
pub async fn receive_message(&self) -> Option<Message> {
self.receiver.lock().await.recv().await
}
pub async fn close(&self) -> Result<(), mpsc::error::SendError<Message>> {
let _ = RECONNECT_TOKEN.lock().await.take();
self.sender.send(Message::Close(None))
}
async fn listen(&self, config: Config, restart_tx: UnboundedSender<Option<Error>>) -> Result<(), Error> {
let (tunnel_tx, mut tunnel_rx) = unbounded::<ControlPacket>();
let mut restart = restart_tx.clone();
let sender_clone = self.sender.clone();
tokio::spawn(async move {
while let Some(packet) = tunnel_rx.next().await {
let message = Message::binary(packet.serialize()); match sender_clone.send(message) {
Ok(_) => {} Err(e) => {
warn!("Failed to send message to WebSocket tunnel: {:?}", e);
let _ = restart.send(Some(Error::Timeout)).await;
return;
}
}
}
});
let mut restart = restart_tx.clone();
let receiver_clone = self.receiver.clone();
tokio::spawn(async move {
loop {
let mut receiver = receiver_clone.lock().await;
match receiver.recv().await {
Some(message) if message.is_close() => {
debug!("got close message");
let _ = restart.send(None).await;
return Ok(());
}
Some(message) => {
let packet = process_control_flow_message(
config.clone(),
tunnel_tx.clone(),
message.into_data(),
)
.await
.map_err(|e| {
error!("Malformed protocol control packet: {:?}", e);
Error::MalformedMessageFromServer
})?;
debug!("Processed packet: {:?}", packet.packet_type());
}
None => {
warn!("websocket sent none");
return Err(Error::Timeout);
}
}
}
});
Ok(())
}
}
async fn process_control_flow_message(
config: Config,
mut tunnel_tx: UnboundedSender<ControlPacket>,
payload: Vec<u8>,
) -> Result<ControlPacket, Box<dyn std::error::Error>> {
let control_packet = ControlPacket::deserialize(&payload)?;
match &control_packet {
ControlPacket::Init(stream_id) => {
info!("stream[{:?}] -> init", stream_id.to_string());
}
ControlPacket::Ping(reconnect_token) => {
log::info!("got ping. reconnect_token={}", reconnect_token.is_some());
if let Some(reconnect) = reconnect_token {
let _ = RECONNECT_TOKEN.lock().await.replace(reconnect.clone());
}
let _ = tunnel_tx.send(ControlPacket::Ping(None)).await;
}
ControlPacket::Refused(_) => return Err("unexpected control packet".into()),
ControlPacket::End(stream_id) => {
let stream_id = stream_id.clone();
info!("got end stream [{:?}]", &stream_id);
tokio::spawn(async move {
let stream = ACTIVE_STREAMS.read().unwrap().get(&stream_id).cloned();
if let Some(mut tx) = stream {
tokio::time::sleep(Duration::from_secs(5)).await;
let _ = tx.send(StreamMessage::Close).await.map_err(|e| {
error!("failed to send stream close: {:?}", e);
});
ACTIVE_STREAMS.write().unwrap().remove(&stream_id);
}
});
}
ControlPacket::Data(stream_id, data) => {
info!(
"stream[{:?}] -> new data: {:?}",
stream_id.to_string(),
data.len()
);
if !ACTIVE_STREAMS.read().unwrap().contains_key(&stream_id) {
if local::setup_new_stream(config.clone(), tunnel_tx.clone(), stream_id.clone())
.await
.is_none()
{
error!("failed to open local tunnel")
}
}
let active_stream = ACTIVE_STREAMS.read().unwrap().get(&stream_id).cloned();
if let Some(mut tx) = active_stream {
tx.send(StreamMessage::Data(data.clone())).await?;
info!("forwarded to local tcp ({})", stream_id.to_string());
} else {
error!("got data but no stream to send it to.");
let _ = tunnel_tx
.send(ControlPacket::Refused(stream_id.clone()))
.await?;
}
}
};
Ok(control_packet.clone())
}