Skip to main content

websock_tungstenite/
connection.rs

1//! Connection management for the Tokio Tungstenite transport.
2
3use rustls::ClientConfig;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio_tungstenite::{Connector, WebSocketStream, tungstenite};
8use tungstenite::client::IntoClientRequest;
9use websock_proto::{ConnectOptions, Error, Message, Result};
10
/// Address and transport metadata for an established connection.
///
/// This is a plain `Copy` value: it is a snapshot and does not follow any later
/// changes to the underlying socket. It can be compared and hashed, e.g. to
/// de-duplicate or look up connections by their endpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionInfo {
    /// Remote peer address for the connection.
    pub peer: SocketAddr,
    /// Local socket address for the connection.
    pub local: SocketAddr,
    /// True when the connection is established over TLS.
    pub is_tls: bool,
}
20
21/// Establish a WebSocket connection using Tokio Tungstenite.
22pub async fn connect(url: &str, opts: ConnectOptions) -> Result<Connection> {
23    connect_with_tls(url, opts, None).await
24}
25
26/// Establish a WebSocket connection using a custom TLS configuration.
27pub async fn connect_with_tls(
28    url: &str,
29    opts: ConnectOptions,
30    tls: Option<Arc<ClientConfig>>,
31) -> Result<Connection> {
32    let mut req = url
33        .into_client_request()
34        .map_err(|e| Error::InvalidUrl(e.to_string()))?;
35
36    // Apply configured headers and subprotocols.
37    {
38        let headers = req.headers_mut();
39        for (k, v) in opts.headers {
40            let name = tungstenite::http::header::HeaderName::from_bytes(k.as_bytes())
41                .map_err(|e| Error::Protocol(format!("invalid header name: {e}")))?;
42            let value = tungstenite::http::header::HeaderValue::from_str(&v)
43                .map_err(|e| Error::Protocol(format!("invalid header value: {e}")))?;
44            headers.append(name, value);
45        }
46
47        // Apply subprotocols.
48        if !opts.protocols.is_empty() {
49            let joined = opts.protocols.join(",");
50            let value = tungstenite::http::header::HeaderValue::from_str(&joined)
51                .map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?;
52            headers.insert(tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL, value);
53        }
54    }
55
56    let connector = tls.map(Connector::Rustls);
57    let (ws, _resp) = tokio_tungstenite::connect_async_tls_with_config(req, None, false, connector)
58        .await
59        .map_err(map_tungstenite_err)?;
60
61    let info = ConnectionInfo {
62        peer: ws
63            .get_ref()
64            .get_ref()
65            .peer_addr()
66            .map_err(|e| Error::Io(e.to_string()))?,
67        local: ws
68            .get_ref()
69            .get_ref()
70            .local_addr()
71            .map_err(|e| Error::Io(e.to_string()))?,
72        is_tls: matches!(ws.get_ref(), tokio_tungstenite::MaybeTlsStream::Rustls(_)),
73    };
74
75    Ok(Connection { ws, info })
76}
77
78/// WebSocket connection wrapper around a Tokio Tungstenite stream.
79pub struct Connection<S = tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
80    pub(crate) ws: WebSocketStream<S>,
81    pub(crate) info: ConnectionInfo,
82}
83
84impl<S> Connection<S>
85where
86    S: AsyncRead + AsyncWrite + Unpin,
87{
88    /// Send a text or binary message.
89    pub async fn send(&mut self, msg: Message) -> Result<()> {
90        use futures_util::SinkExt;
91
92        let tmsg = match msg {
93            Message::Text(s) => tungstenite::Message::Text(s.into()),
94            Message::Binary(b) => tungstenite::Message::Binary(b),
95        };
96
97        self.ws.send(tmsg).await.map_err(map_tungstenite_err)?;
98        Ok(())
99    }
100
101    /// Receive the next text or binary message, responding to pings as needed.
102    pub async fn recv(&mut self) -> Result<Message> {
103        use futures_util::{SinkExt, StreamExt};
104
105        loop {
106            let item = self.ws.next().await.ok_or(Error::Closed)?;
107            let msg = item.map_err(map_tungstenite_err)?;
108
109            match msg {
110                tungstenite::Message::Ping(p) => {
111                    self.ws
112                        .send(tungstenite::Message::Pong(p))
113                        .await
114                        .map_err(map_tungstenite_err)?;
115                    continue;
116                }
117                tungstenite::Message::Pong(_) => continue,
118                tungstenite::Message::Text(s) => return Ok(Message::Text(s.to_string())),
119                tungstenite::Message::Binary(b) => return Ok(Message::Binary(b)),
120                tungstenite::Message::Close(_) => {
121                    let _ = self.ws.close(None).await;
122                    return Err(Error::Closed);
123                }
124                _ => return Err(Error::Protocol("unsupported ws message".into())),
125            }
126        }
127    }
128
129    /// Close the WebSocket connection gracefully.
130    pub async fn close(&mut self) -> Result<()> {
131        self.ws.close(None).await.map_err(map_tungstenite_err)?;
132        Ok(())
133    }
134
135    /// Borrow the underlying transport stream.
136    pub fn get_ref(&self) -> &S {
137        self.ws.get_ref()
138    }
139
140    /// Mutably borrow the underlying transport stream.
141    pub fn get_mut(&mut self) -> &mut S {
142        self.ws.get_mut()
143    }
144}
145
146impl<S> websock_proto::WebSocketConnection for Connection<S>
147where
148    S: AsyncRead + AsyncWrite + Unpin,
149{
150    fn send<'a>(&'a mut self, msg: Message) -> websock_proto::LocalBoxFuture<'a, Result<()>> {
151        Box::pin(async move { Connection::send(self, msg).await })
152    }
153
154    fn recv<'a>(&'a mut self) -> websock_proto::LocalBoxFuture<'a, Result<Message>> {
155        Box::pin(async move { Connection::recv(self).await })
156    }
157
158    fn close<'a>(&'a mut self) -> websock_proto::LocalBoxFuture<'a, Result<()>> {
159        Box::pin(async move { Connection::close(self).await })
160    }
161}
162
163impl<S> Connection<S> {
164    /// Return the peer address.
165    pub fn peer_addr(&self) -> SocketAddr {
166        self.info.peer
167    }
168    /// Return the local address.
169    pub fn local_addr(&self) -> SocketAddr {
170        self.info.local
171    }
172    /// Report whether TLS is in use.
173    pub fn is_tls(&self) -> bool {
174        self.info.is_tls
175    }
176    /// Return the full connection metadata snapshot.
177    pub fn info(&self) -> ConnectionInfo {
178        self.info
179    }
180}
181
182/// Map tungstenite errors into the shared error type.
183pub(crate) fn map_tungstenite_err(e: tungstenite::Error) -> Error {
184    use tungstenite::Error as E;
185    match e {
186        E::ConnectionClosed | E::AlreadyClosed => Error::Closed,
187        E::Io(io) => Error::Io(io.to_string()),
188        E::Tls(tls) => Error::Tls(tls.to_string()),
189        E::Url(url) => Error::InvalidUrl(url.to_string()),
190        E::Protocol(err) => Error::Protocol(err.to_string()),
191        E::Utf8(err) => Error::Protocol(err),
192        E::Capacity(err) => Error::Protocol(err.to_string()),
193        E::HttpFormat(err) => Error::Protocol(err.to_string()),
194        other => Error::Other(other.to_string()),
195    }
196}