actnel/
lib.rs

1use futures::channel::mpsc::{unbounded, UnboundedSender};
2use futures::{SinkExt, StreamExt};
3
4use tokio_tungstenite::tungstenite::Message;
5
6use log::{debug, error, info, warn};
7
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::{Arc, RwLock};
11
12pub mod config;
13pub mod error;
14mod local;
15
16use crate::config::*;
17use crate::error::*;
18use actnel_lib::*;
19pub use actnel_lib::DeviceId;
20
21use std::time::Duration;
22use tokio::sync::{mpsc, Mutex};
23
24pub type ActiveStreams = Arc<RwLock<HashMap<StreamId, UnboundedSender<StreamMessage>>>>;
25
26lazy_static::lazy_static! {
27    pub static ref ACTIVE_STREAMS:ActiveStreams = Arc::new(RwLock::new(HashMap::new()));
28    pub static ref RECONNECT_TOKEN: Arc<Mutex<Option<ReconnectToken>>> = Arc::new(Mutex::new(None));
29}
30
/// Message delivered to a local stream's task through the channel stored for
/// it in `ACTIVE_STREAMS`.
#[derive(Clone, Debug)]
pub enum StreamMessage {
    /// Payload bytes received from the server for this stream.
    Data(Vec<u8>),
    /// Instruction to close the local stream.
    Close,
}
36
37pub struct Session {
38    config: Config,
39    wormhole: Wormhole,
40}
41
42impl Session {
43    pub async fn connect(config: Config) -> Result<Self, Error> {
44        // let config = match Config::get() {
45        //     Ok(config) => config,
46        //     Err(_) => return Err(Error::Timeout),
47        // };
48
49        let wormhole = Wormhole::connect(&config).await?;
50        Ok(Session {
51            config,
52            wormhole,
53        })
54    }
55
56    pub async fn listen(&self) -> Result<SocketAddr, Error> {
57        let config = self.config.clone();
58        let (restart_tx, _) = unbounded();
59        let _ = self.wormhole.listen(config, restart_tx).await;
60        // self.config.first_run = false;
61        Ok(self.config.local_addr.clone())
62    }
63
64    pub async fn close(&self) -> Result<(), Error> {
65        let _ = self.wormhole.close().await;
66        Ok(())
67    }
68
69    pub fn ingress_url(&self) -> String {
70        self.config.activation_url(self.wormhole.hostname.as_str())
71    }
72
73    pub fn quotas(&self) -> ClientQuotas {
74        self.wormhole.quotas.clone()
75    }
76}
77
78struct Wormhole {
79    sender: mpsc::UnboundedSender<Message>,
80    receiver: Arc<Mutex<mpsc::UnboundedReceiver<Message>>>,
81    sub_domain: String,
82    hostname: String,
83    quotas: ClientQuotas,
84}
85
86impl Wormhole {
87    // Function to create a new Wormhole connection
88    async fn connect(config: &Config) -> Result<Self, Error> {
89        let (mut websocket, _) = tokio_tungstenite::connect_async(&config.control_url).await?;
90
91        // send our Client Hello message
92        let client_hello = match config.secret_key.clone() {
93            Some(secret_key) => ClientHello::generate(
94                config.sub_domain.clone(),
95                ClientType::Auth { key: secret_key },
96            ),
97            None => {
98                // if we have a reconnect token, use it.
99                if let Some(reconnect) = RECONNECT_TOKEN.lock().await.clone() {
100                    ClientHello::reconnect(reconnect)
101                } else {
102                    ClientHello::generate(config.sub_domain.clone(), ClientType::Anonymous)
103                }
104            }
105        };
106
107        info!("connecting to wormhole...");
108
109        let hello = serde_json::to_vec(&client_hello).unwrap();
110        websocket
111            .send(Message::binary(hello))
112            .await
113            .expect("Failed to send client hello to wormhole server.");
114
115        // wait for Server hello
116        let server_hello_data = websocket
117            .next()
118            .await
119            .ok_or(Error::NoResponseFromServer)??
120            .into_data();
121        let server_hello = serde_json::from_slice::<ServerHello>(&server_hello_data).map_err(|e| {
122            error!("Couldn't parse server_hello from {:?}", e);
123            Error::ServerReplyInvalid
124        })?;
125
126        let (sub_domain, hostname, quotas) = match server_hello {
127            ServerHello::Success {
128                sub_domain,
129                client_id,
130                hostname,
131                quotas,
132            } => {
133                info!("Server accepted our connection. I am client_{}", client_id);
134                (sub_domain, hostname, quotas)
135            }
136            ServerHello::AuthFailed => {
137                return Err(Error::AuthenticationFailed);
138            }
139            ServerHello::InvalidSubDomain => {
140                return Err(Error::InvalidSubDomain);
141            }
142            ServerHello::SubDomainInUse => {
143                return Err(Error::SubDomainInUse);
144            }
145            ServerHello::Error(error) => return Err(Error::ServerError(error)),
146        };
147
148        let (receive_tx, receive_rx) = mpsc::unbounded_channel();
149        let (send_tx, mut send_rx) = mpsc::unbounded_channel();
150        // Spawn a task to handle the WebSocket
151        tokio::spawn({
152            async move {
153                let mut ws_stream = websocket;
154                loop {
155                    tokio::select! {
156                        message = ws_stream.next() => {
157                            match message {
158                                Some(Ok(msg)) => {
159                                    if receive_tx.send(msg).is_err() {
160                                        break; // Channel closed
161                                    }
162                                }
163                                Some(Err(e)) => { // WebSocket error
164                                    warn!("websocket read error: {:?}", e);
165                                    break;
166                                },
167                                None => { // WebSocket closed
168                                    warn!("websocket sent none");
169                                    break;
170                                },
171                            }
172                        }
173                        received = async {
174                            send_rx.recv().await
175                        } => {
176                            // received is the result of locked_receiver.recv().await
177                            if let Some(msg) = received {
178                                if ws_stream.send(msg).await.is_err() {
179                                    break; // WebSocket error or closed
180                                }
181                            } else {
182                                break; // Channel closed
183                            }
184                        }
185                    }
186                }
187            }
188        });
189
190        Ok(Wormhole {
191            sender: send_tx,
192            receiver: Arc::new(Mutex::new(receive_rx)),
193            sub_domain,
194            hostname,
195            quotas,
196        })
197    }
198
199    // Send a message through the WebSocket
200    pub async fn send_message(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
201        self.sender.send(message)
202    }
203
204    // Receive a message from the WebSocket
205    pub async fn receive_message(&self) -> Option<Message> {
206        self.receiver.lock().await.recv().await
207    }
208
209    // Close the WebSocket
210    pub async fn close(&self) -> Result<(), mpsc::error::SendError<Message>> {
211        // Set the reconnect token to None
212        let _ = RECONNECT_TOKEN.lock().await.take();
213        // Send a close message
214        self.sender.send(Message::Close(None))
215        // solution for Protocol(ResetWithoutClosingHandshake)
216        // but doesn't work
217        // let _ = self.sender.send(Message::Close(None));
218        // let mut non_close_message_count = 0;
219        // while let Some(message) = self.receiver.lock().await.recv().await {
220        //     match message {
221        //         Message::Close(_) => {
222        //             // Acknowledgment received, break the loop
223        //             break;
224        //         }
225        //         _ => {
226        //             // Log other messages if necessary, but keep waiting for the close acknowledgment
227        //             debug!("Received non-close message during closing process: {:?}", message);
228        //             non_close_message_count += 1;
229        //             if non_close_message_count >= 10 {
230        //                 // Error out for timeout
231        //                 break;
232        //             }
233        //         }
234        //     }
235        // }
236        // Ok(())
237    }
238
239    async fn listen(&self, config: Config, restart_tx: UnboundedSender<Option<Error>>) -> Result<(), Error> {
240        // tunnel channel
241        let (tunnel_tx, mut tunnel_rx) = unbounded::<ControlPacket>();
242
243        // continuously write to websocket tunnel
244        let mut restart = restart_tx.clone();
245        let sender_clone = self.sender.clone();
246        tokio::spawn(async move {
247            while let Some(packet) = tunnel_rx.next().await {
248                let message = Message::binary(packet.serialize());  // Assuming ControlPacket has a serialize method
249                match sender_clone.send(message) {
250                    Ok(_) => {}  // Successfully sent the message
251                    Err(e) => {
252                        // Handle the error
253                        warn!("Failed to send message to WebSocket tunnel: {:?}", e);
254                        // let _ = restart_tx.send(Some(Error::WebSocketError(e))).await;
255                        let _ = restart.send(Some(Error::Timeout)).await;
256                        return;
257                    }
258                }
259            }
260        });
261
262        // continuously read from websocket tunnel
263        let mut restart = restart_tx.clone();
264        let receiver_clone = self.receiver.clone();
265        tokio::spawn(async move {
266            loop {
267                let mut receiver = receiver_clone.lock().await;
268                match receiver.recv().await {
269                    Some(message) if message.is_close() => {
270                        debug!("got close message");
271                        let _ = restart.send(None).await;
272                        return Ok(());
273                    }
274                    Some(message) => {
275                        let packet = process_control_flow_message(
276                            config.clone(),
277                            tunnel_tx.clone(),
278                            message.into_data(),
279                        )
280                        .await
281                        .map_err(|e| {
282                            error!("Malformed protocol control packet: {:?}", e);
283                            Error::MalformedMessageFromServer
284                        })?;
285                        debug!("Processed packet: {:?}", packet.packet_type());
286                    }
287                    None => {
288                        warn!("websocket sent none");
289                        return Err(Error::Timeout);
290                    }
291                }
292            }
293        });
294
295        Ok(())
296    }
297}
298
299async fn process_control_flow_message(
300    config: Config,
301    mut tunnel_tx: UnboundedSender<ControlPacket>,
302    payload: Vec<u8>,
303) -> Result<ControlPacket, Box<dyn std::error::Error>> {
304    let control_packet = ControlPacket::deserialize(&payload)?;
305
306    match &control_packet {
307        ControlPacket::Init(stream_id) => {
308            info!("stream[{:?}] -> init", stream_id.to_string());
309        }
310        ControlPacket::Ping(reconnect_token) => {
311            log::info!("got ping. reconnect_token={}", reconnect_token.is_some());
312
313            if let Some(reconnect) = reconnect_token {
314                let _ = RECONNECT_TOKEN.lock().await.replace(reconnect.clone());
315            }
316            let _ = tunnel_tx.send(ControlPacket::Ping(None)).await;
317        }
318        ControlPacket::Refused(_) => return Err("unexpected control packet".into()),
319        ControlPacket::End(stream_id) => {
320            // find the stream
321            let stream_id = stream_id.clone();
322
323            info!("got end stream [{:?}]", &stream_id);
324
325            tokio::spawn(async move {
326                let stream = ACTIVE_STREAMS.read().unwrap().get(&stream_id).cloned();
327                if let Some(mut tx) = stream {
328                    tokio::time::sleep(Duration::from_secs(5)).await;
329                    let _ = tx.send(StreamMessage::Close).await.map_err(|e| {
330                        error!("failed to send stream close: {:?}", e);
331                    });
332                    ACTIVE_STREAMS.write().unwrap().remove(&stream_id);
333                }
334            });
335        }
336        ControlPacket::Data(stream_id, data) => {
337            info!(
338                "stream[{:?}] -> new data: {:?}",
339                stream_id.to_string(),
340                data.len()
341            );
342
343            if !ACTIVE_STREAMS.read().unwrap().contains_key(&stream_id) {
344                if local::setup_new_stream(config.clone(), tunnel_tx.clone(), stream_id.clone())
345                    .await
346                    .is_none()
347                {
348                    error!("failed to open local tunnel")
349                }
350            }
351
352            // find the right stream
353            let active_stream = ACTIVE_STREAMS.read().unwrap().get(&stream_id).cloned();
354
355            // forward data to it
356            if let Some(mut tx) = active_stream {
357                tx.send(StreamMessage::Data(data.clone())).await?;
358                info!("forwarded to local tcp ({})", stream_id.to_string());
359            } else {
360                error!("got data but no stream to send it to.");
361                let _ = tunnel_tx
362                    .send(ControlPacket::Refused(stream_id.clone()))
363                    .await?;
364            }
365        }
366    };
367
368    Ok(control_packet.clone())
369}