//! rustgate `ws.rs` — WebSocket transport helpers over pre-established mTLS
//! streams, plus `ChannelMap` for multiplexing data channels over one socket.
1use crate::error::Result;
2use bytes::Bytes;
3use std::collections::{HashMap, HashSet};
4use std::sync::atomic::{AtomicU32, Ordering};
5use tokio::net::TcpStream;
6use tokio::sync::{mpsc, oneshot, Mutex};
7use tokio_rustls::client::TlsStream as ClientTlsStream;
8use tokio_rustls::server::TlsStream as ServerTlsStream;
9use tokio_tungstenite::WebSocketStream;
10
/// Server-side WebSocket running over an mTLS (rustls) stream accepted from a peer.
pub type ServerWsStream = WebSocketStream<ServerTlsStream<TcpStream>>;
/// Client-side WebSocket running over an mTLS (rustls) stream we initiated.
pub type ClientWsStream = WebSocketStream<ClientTlsStream<TcpStream>>;
13
14/// Accept a WebSocket connection over an already-established mTLS server stream.
15pub async fn accept_ws(tls_stream: ServerTlsStream<TcpStream>) -> Result<ServerWsStream> {
16    let ws = tokio_tungstenite::accept_async(tls_stream).await?;
17    Ok(ws)
18}
19
20/// Connect as a WebSocket client over an already-established mTLS client stream.
21pub async fn connect_ws(
22    tls_stream: ClientTlsStream<TcpStream>,
23    url: &str,
24) -> Result<ClientWsStream> {
25    let (ws, _response) = tokio_tungstenite::client_async(url, tls_stream).await?;
26    Ok(ws)
27}
28
/// Internal state protected by a single Mutex to prevent lock-order deadlocks.
struct ChannelState {
    // Per-channel sender used to forward incoming frames to that channel's task.
    channels: HashMap<u32, mpsc::Sender<Bytes>>,
    // At most one pending readiness waiter per channel (see `wait_ready`/`signal_ready`).
    ready_signals: HashMap<u32, oneshot::Sender<()>>,
    // tunnel_id -> channel_ids belonging to it; enables bulk close in `close_tunnel`.
    tunnel_channels: HashMap<u32, HashSet<u32>>,
}
35
/// Manages multiplexed data channels over a single WebSocket connection.
///
/// Each channel has a unique u32 ID. Client-originated channels use odd IDs,
/// server-originated channels use even IDs to avoid collisions.
pub struct ChannelMap {
    // All mutable maps live behind one Mutex so there is exactly one lock to take.
    state: Mutex<ChannelState>,
    // Next ID to hand out; stepped by 2 in `alloc_id` to preserve odd/even parity.
    next_id: AtomicU32,
}
44
45impl ChannelMap {
46    /// Create a new ChannelMap. `start_id` should be 1 for clients (odd), 2 for servers (even).
47    pub fn new(start_id: u32) -> Self {
48        Self {
49            state: Mutex::new(ChannelState {
50                channels: HashMap::new(),
51                ready_signals: HashMap::new(),
52                tunnel_channels: HashMap::new(),
53            }),
54            next_id: AtomicU32::new(start_id),
55        }
56    }
57
58    /// Allocate the next channel ID (increments by 2 to maintain odd/even parity).
59    pub fn alloc_id(&self) -> u32 {
60        self.next_id.fetch_add(2, Ordering::Relaxed)
61    }
62
63    /// Check if a channel_id is already registered.
64    pub async fn has(&self, channel_id: u32) -> bool {
65        self.state.lock().await.channels.contains_key(&channel_id)
66    }
67
68    /// Register a channel with its sender.
69    pub async fn insert(&self, channel_id: u32, sender: mpsc::Sender<Bytes>) {
70        self.state.lock().await.channels.insert(channel_id, sender);
71    }
72
73    /// Register a channel and associate it with a tunnel_id for lifecycle tracking.
74    pub async fn insert_with_tunnel(
75        &self,
76        channel_id: u32,
77        tunnel_id: u32,
78        sender: mpsc::Sender<Bytes>,
79    ) {
80        let mut s = self.state.lock().await;
81        s.channels.insert(channel_id, sender);
82        s.tunnel_channels
83            .entry(tunnel_id)
84            .or_default()
85            .insert(channel_id);
86    }
87
88    /// Route data to a channel. Returns false if channel not found or closed.
89    /// Uses try_send so the shared WS reader never blocks on one slow channel.
90    /// If the buffer is full, the channel is closed cleanly (removed + returns false).
91    pub async fn send(&self, channel_id: u32, data: Bytes) -> bool {
92        let tx = {
93            let s = self.state.lock().await;
94            s.channels.get(&channel_id).cloned()
95        };
96        if let Some(tx) = tx {
97            match tx.try_send(data) {
98                Ok(()) => true,
99                Err(mpsc::error::TrySendError::Full(_)) => {
100                    // Channel congested — close it to preserve session liveness.
101                    // The relay task will see the sender drop and clean up.
102                    self.remove(channel_id).await;
103                    false
104                }
105                Err(mpsc::error::TrySendError::Closed(_)) => false,
106            }
107        } else {
108            false
109        }
110    }
111
112    /// Remove a channel and cancel any pending readiness waiter.
113    pub async fn remove(&self, channel_id: u32) {
114        let mut s = self.state.lock().await;
115        s.channels.remove(&channel_id);
116        s.ready_signals.remove(&channel_id);
117        for set in s.tunnel_channels.values_mut() {
118            set.remove(&channel_id);
119        }
120    }
121
122    /// Close ALL channels — used on session disconnect.
123    pub async fn close_all(&self) {
124        let mut s = self.state.lock().await;
125        s.channels.clear();
126        s.ready_signals.clear();
127        s.tunnel_channels.clear();
128    }
129
130    /// Close all channels belonging to a tunnel. Returns the channel IDs that were removed.
131    pub async fn close_tunnel(&self, tunnel_id: u32) -> Vec<u32> {
132        let mut s = self.state.lock().await;
133        let channel_ids: Vec<u32> = s
134            .tunnel_channels
135            .remove(&tunnel_id)
136            .unwrap_or_default()
137            .into_iter()
138            .collect();
139        for &id in &channel_ids {
140            s.channels.remove(&id);
141            s.ready_signals.remove(&id);
142        }
143        channel_ids
144    }
145
146    /// Register a readiness waiter for a channel.
147    pub async fn wait_ready(&self, channel_id: u32) -> oneshot::Receiver<()> {
148        let (tx, rx) = oneshot::channel();
149        self.state
150            .lock()
151            .await
152            .ready_signals
153            .insert(channel_id, tx);
154        rx
155    }
156
157    /// Signal that a channel is ready. Returns true if a waiter was notified.
158    pub async fn signal_ready(&self, channel_id: u32) -> bool {
159        if let Some(tx) = self.state.lock().await.ready_signals.remove(&channel_id) {
160            tx.send(()).is_ok()
161        } else {
162            false
163        }
164    }
165}