Skip to main content

cdk_http_client/ws/native.rs

1//! Native WebSocket implementation using tokio-tungstenite
2
3use futures::{SinkExt, StreamExt};
4use tokio_tungstenite::tungstenite::client::IntoClientRequest;
5use tokio_tungstenite::tungstenite::Message;
6
7use super::WsError;
8
9/// WebSocket sender half
10pub struct WsSender {
11    inner: Box<
12        dyn futures::Sink<Message, Error = tokio_tungstenite::tungstenite::Error> + Unpin + Send,
13    >,
14}
15
16impl std::fmt::Debug for WsSender {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("WsSender").finish_non_exhaustive()
19    }
20}
21
22/// WebSocket receiver half
23pub struct WsReceiver {
24    inner: Box<
25        dyn futures::Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>>
26            + Unpin
27            + Send,
28    >,
29}
30
31impl std::fmt::Debug for WsReceiver {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("WsReceiver").finish_non_exhaustive()
34    }
35}
36
37impl WsSender {
38    /// Send a text message over the WebSocket
39    pub async fn send(&mut self, text: String) -> Result<(), WsError> {
40        self.inner
41            .send(Message::Text(text.into()))
42            .await
43            .map_err(|e| WsError::Send(e.to_string()))
44    }
45
46    /// Send a close frame
47    pub async fn close(&mut self) -> Result<(), WsError> {
48        self.inner
49            .send(Message::Close(None))
50            .await
51            .map_err(|e| WsError::Send(e.to_string()))
52    }
53}
54
55impl WsReceiver {
56    /// Receive the next text message. Returns `None` when the connection is closed.
57    /// Non-text messages are silently skipped.
58    pub async fn recv(&mut self) -> Option<Result<String, WsError>> {
59        loop {
60            match self.inner.next().await {
61                Some(Ok(Message::Text(text))) => return Some(Ok(text.to_string())),
62                Some(Ok(Message::Close(_))) | None => return None,
63                Some(Ok(_)) => continue, // skip binary, ping, pong
64                Some(Err(e)) => return Some(Err(WsError::Receive(e.to_string()))),
65            }
66        }
67    }
68}
69
70/// Connect to a WebSocket endpoint with optional headers.
71///
72/// `headers` is a slice of `(name, value)` pairs to include in the upgrade request.
73pub async fn connect(
74    url: &str,
75    headers: &[(&str, &str)],
76) -> Result<(WsSender, WsReceiver), WsError> {
77    let mut request = url
78        .into_client_request()
79        .map_err(|e| WsError::Connection(e.to_string()))?;
80
81    for &(name, value) in headers {
82        if let (Ok(header_name), Ok(header_value)) = (
83            name.parse::<tokio_tungstenite::tungstenite::http::header::HeaderName>(),
84            value.parse::<tokio_tungstenite::tungstenite::http::header::HeaderValue>(),
85        ) {
86            request.headers_mut().insert(header_name, header_value);
87        }
88    }
89
90    let (ws_stream, _) = tokio_tungstenite::connect_async(request)
91        .await
92        .map_err(|e| WsError::Connection(e.to_string()))?;
93
94    let (sink, stream) = ws_stream.split();
95
96    Ok((
97        WsSender {
98            inner: Box::new(sink),
99        },
100        WsReceiver {
101            inner: Box::new(stream),
102        },
103    ))
104}