1use futures::channel::mpsc::{unbounded, UnboundedSender};
2use futures::{SinkExt, StreamExt};
3
4use tokio_tungstenite::tungstenite::Message;
5
6use log::{debug, error, info, warn};
7
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::{Arc, RwLock};
11
12pub mod config;
13pub mod error;
14mod local;
15
16use crate::config::*;
17use crate::error::*;
18use actnel_lib::*;
19pub use actnel_lib::DeviceId;
20
21use std::time::Duration;
22use tokio::sync::{mpsc, Mutex};
23
24pub type ActiveStreams = Arc<RwLock<HashMap<StreamId, UnboundedSender<StreamMessage>>>>;
25
26lazy_static::lazy_static! {
27 pub static ref ACTIVE_STREAMS:ActiveStreams = Arc::new(RwLock::new(HashMap::new()));
28 pub static ref RECONNECT_TOKEN: Arc<Mutex<Option<ReconnectToken>>> = Arc::new(Mutex::new(None));
29}
30
/// Message delivered to a single local stream through its entry in `ACTIVE_STREAMS`.
///
/// `PartialEq`/`Eq` are derived so messages can be compared in tests and assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMessage {
    /// A chunk of payload bytes received from the server for this stream.
    Data(Vec<u8>),
    /// The server ended the stream; the local side should shut down.
    Close,
}
36
37pub struct Session {
38 config: Config,
39 wormhole: Wormhole,
40}
41
42impl Session {
43 pub async fn connect(config: Config) -> Result<Self, Error> {
44 let wormhole = Wormhole::connect(&config).await?;
50 Ok(Session {
51 config,
52 wormhole,
53 })
54 }
55
56 pub async fn listen(&self) -> Result<SocketAddr, Error> {
57 let config = self.config.clone();
58 let (restart_tx, _) = unbounded();
59 let _ = self.wormhole.listen(config, restart_tx).await;
60 Ok(self.config.local_addr.clone())
62 }
63
64 pub async fn close(&self) -> Result<(), Error> {
65 let _ = self.wormhole.close().await;
66 Ok(())
67 }
68
69 pub fn ingress_url(&self) -> String {
70 self.config.activation_url(self.wormhole.hostname.as_str())
71 }
72
73 pub fn quotas(&self) -> ClientQuotas {
74 self.wormhole.quotas.clone()
75 }
76}
77
78struct Wormhole {
79 sender: mpsc::UnboundedSender<Message>,
80 receiver: Arc<Mutex<mpsc::UnboundedReceiver<Message>>>,
81 sub_domain: String,
82 hostname: String,
83 quotas: ClientQuotas,
84}
85
86impl Wormhole {
87 async fn connect(config: &Config) -> Result<Self, Error> {
89 let (mut websocket, _) = tokio_tungstenite::connect_async(&config.control_url).await?;
90
91 let client_hello = match config.secret_key.clone() {
93 Some(secret_key) => ClientHello::generate(
94 config.sub_domain.clone(),
95 ClientType::Auth { key: secret_key },
96 ),
97 None => {
98 if let Some(reconnect) = RECONNECT_TOKEN.lock().await.clone() {
100 ClientHello::reconnect(reconnect)
101 } else {
102 ClientHello::generate(config.sub_domain.clone(), ClientType::Anonymous)
103 }
104 }
105 };
106
107 info!("connecting to wormhole...");
108
109 let hello = serde_json::to_vec(&client_hello).unwrap();
110 websocket
111 .send(Message::binary(hello))
112 .await
113 .expect("Failed to send client hello to wormhole server.");
114
115 let server_hello_data = websocket
117 .next()
118 .await
119 .ok_or(Error::NoResponseFromServer)??
120 .into_data();
121 let server_hello = serde_json::from_slice::<ServerHello>(&server_hello_data).map_err(|e| {
122 error!("Couldn't parse server_hello from {:?}", e);
123 Error::ServerReplyInvalid
124 })?;
125
126 let (sub_domain, hostname, quotas) = match server_hello {
127 ServerHello::Success {
128 sub_domain,
129 client_id,
130 hostname,
131 quotas,
132 } => {
133 info!("Server accepted our connection. I am client_{}", client_id);
134 (sub_domain, hostname, quotas)
135 }
136 ServerHello::AuthFailed => {
137 return Err(Error::AuthenticationFailed);
138 }
139 ServerHello::InvalidSubDomain => {
140 return Err(Error::InvalidSubDomain);
141 }
142 ServerHello::SubDomainInUse => {
143 return Err(Error::SubDomainInUse);
144 }
145 ServerHello::Error(error) => return Err(Error::ServerError(error)),
146 };
147
148 let (receive_tx, receive_rx) = mpsc::unbounded_channel();
149 let (send_tx, mut send_rx) = mpsc::unbounded_channel();
150 tokio::spawn({
152 async move {
153 let mut ws_stream = websocket;
154 loop {
155 tokio::select! {
156 message = ws_stream.next() => {
157 match message {
158 Some(Ok(msg)) => {
159 if receive_tx.send(msg).is_err() {
160 break; }
162 }
163 Some(Err(e)) => { warn!("websocket read error: {:?}", e);
165 break;
166 },
167 None => { warn!("websocket sent none");
169 break;
170 },
171 }
172 }
173 received = async {
174 send_rx.recv().await
175 } => {
176 if let Some(msg) = received {
178 if ws_stream.send(msg).await.is_err() {
179 break; }
181 } else {
182 break; }
184 }
185 }
186 }
187 }
188 });
189
190 Ok(Wormhole {
191 sender: send_tx,
192 receiver: Arc::new(Mutex::new(receive_rx)),
193 sub_domain,
194 hostname,
195 quotas,
196 })
197 }
198
199 pub async fn send_message(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
201 self.sender.send(message)
202 }
203
204 pub async fn receive_message(&self) -> Option<Message> {
206 self.receiver.lock().await.recv().await
207 }
208
209 pub async fn close(&self) -> Result<(), mpsc::error::SendError<Message>> {
211 let _ = RECONNECT_TOKEN.lock().await.take();
213 self.sender.send(Message::Close(None))
215 }
238
239 async fn listen(&self, config: Config, restart_tx: UnboundedSender<Option<Error>>) -> Result<(), Error> {
240 let (tunnel_tx, mut tunnel_rx) = unbounded::<ControlPacket>();
242
243 let mut restart = restart_tx.clone();
245 let sender_clone = self.sender.clone();
246 tokio::spawn(async move {
247 while let Some(packet) = tunnel_rx.next().await {
248 let message = Message::binary(packet.serialize()); match sender_clone.send(message) {
250 Ok(_) => {} Err(e) => {
252 warn!("Failed to send message to WebSocket tunnel: {:?}", e);
254 let _ = restart.send(Some(Error::Timeout)).await;
256 return;
257 }
258 }
259 }
260 });
261
262 let mut restart = restart_tx.clone();
264 let receiver_clone = self.receiver.clone();
265 tokio::spawn(async move {
266 loop {
267 let mut receiver = receiver_clone.lock().await;
268 match receiver.recv().await {
269 Some(message) if message.is_close() => {
270 debug!("got close message");
271 let _ = restart.send(None).await;
272 return Ok(());
273 }
274 Some(message) => {
275 let packet = process_control_flow_message(
276 config.clone(),
277 tunnel_tx.clone(),
278 message.into_data(),
279 )
280 .await
281 .map_err(|e| {
282 error!("Malformed protocol control packet: {:?}", e);
283 Error::MalformedMessageFromServer
284 })?;
285 debug!("Processed packet: {:?}", packet.packet_type());
286 }
287 None => {
288 warn!("websocket sent none");
289 return Err(Error::Timeout);
290 }
291 }
292 }
293 });
294
295 Ok(())
296 }
297}
298
299async fn process_control_flow_message(
300 config: Config,
301 mut tunnel_tx: UnboundedSender<ControlPacket>,
302 payload: Vec<u8>,
303) -> Result<ControlPacket, Box<dyn std::error::Error>> {
304 let control_packet = ControlPacket::deserialize(&payload)?;
305
306 match &control_packet {
307 ControlPacket::Init(stream_id) => {
308 info!("stream[{:?}] -> init", stream_id.to_string());
309 }
310 ControlPacket::Ping(reconnect_token) => {
311 log::info!("got ping. reconnect_token={}", reconnect_token.is_some());
312
313 if let Some(reconnect) = reconnect_token {
314 let _ = RECONNECT_TOKEN.lock().await.replace(reconnect.clone());
315 }
316 let _ = tunnel_tx.send(ControlPacket::Ping(None)).await;
317 }
318 ControlPacket::Refused(_) => return Err("unexpected control packet".into()),
319 ControlPacket::End(stream_id) => {
320 let stream_id = stream_id.clone();
322
323 info!("got end stream [{:?}]", &stream_id);
324
325 tokio::spawn(async move {
326 let stream = ACTIVE_STREAMS.read().unwrap().get(&stream_id).cloned();
327 if let Some(mut tx) = stream {
328 tokio::time::sleep(Duration::from_secs(5)).await;
329 let _ = tx.send(StreamMessage::Close).await.map_err(|e| {
330 error!("failed to send stream close: {:?}", e);
331 });
332 ACTIVE_STREAMS.write().unwrap().remove(&stream_id);
333 }
334 });
335 }
336 ControlPacket::Data(stream_id, data) => {
337 info!(
338 "stream[{:?}] -> new data: {:?}",
339 stream_id.to_string(),
340 data.len()
341 );
342
343 if !ACTIVE_STREAMS.read().unwrap().contains_key(&stream_id) {
344 if local::setup_new_stream(config.clone(), tunnel_tx.clone(), stream_id.clone())
345 .await
346 .is_none()
347 {
348 error!("failed to open local tunnel")
349 }
350 }
351
352 let active_stream = ACTIVE_STREAMS.read().unwrap().get(&stream_id).cloned();
354
355 if let Some(mut tx) = active_stream {
357 tx.send(StreamMessage::Data(data.clone())).await?;
358 info!("forwarded to local tcp ({})", stream_id.to_string());
359 } else {
360 error!("got data but no stream to send it to.");
361 let _ = tunnel_tx
362 .send(ControlPacket::Refused(stream_id.clone()))
363 .await?;
364 }
365 }
366 };
367
368 Ok(control_packet.clone())
369}