engineio_rs/socket/builder.rs

1use futures_util::StreamExt;
2use reqwest::Url;
3use tracing::trace;
4
5use crate::{
6    error::Result,
7    header::HeaderMap,
8    packet::HandshakePacket,
9    socket::Socket,
10    transports::{
11        polling::ClientPollingTransport, websocket::WebsocketTransport, Transport, TransportType,
12    },
13    Error, Packet, ENGINE_IO_VERSION,
14};
15
16#[derive(Clone, Debug)]
17pub struct SocketBuilder {
18    url: Url,
19    should_pong: bool,
20    headers: Option<HeaderMap>,
21    handshake: Option<HandshakePacket>,
22    channel_size: usize,
23}
24
25impl SocketBuilder {
26    pub fn new(url: Url) -> Self {
27        let mut url = url;
28        url.query_pairs_mut()
29            .append_pair("EIO", &ENGINE_IO_VERSION.to_string());
30
31        // No path add engine.io
32        if url.path() == "/" {
33            url.set_path("/engine.io/");
34        }
35        SocketBuilder {
36            url,
37            headers: None,
38            should_pong: true,
39            handshake: None,
40            channel_size: 100,
41        }
42    }
43
44    pub fn headers(mut self, headers: HeaderMap) -> Self {
45        self.headers = Some(headers);
46        self
47    }
48
49    pub fn channel_buf(mut self, size: usize) -> Self {
50        self.channel_size = size;
51        self
52    }
53
54    async fn handshake_with_transport<T: Transport>(&mut self, transport: &mut T) -> Result<()> {
55        trace!("client handshake_with_transport {:?}", self.handshake);
56        // No need to handshake twice
57        if self.handshake.is_some() {
58            return Ok(());
59        }
60
61        let mut url = self.url.clone();
62
63        let handshake: HandshakePacket =
64            Packet::try_from(transport.next().await.ok_or(Error::IncompletePacket())??)?
65                .try_into()?;
66        trace!("handshake packet {:?}", handshake);
67
68        // update the base_url with the new sid
69        url.query_pairs_mut().append_pair("sid", &handshake.sid[..]);
70
71        self.handshake = Some(handshake);
72
73        self.url = url;
74
75        Ok(())
76    }
77
78    async fn handshake(&mut self) -> Result<()> {
79        trace!("client handshake");
80        if self.handshake.is_some() {
81            return Ok(());
82        }
83
84        let headers = if let Some(map) = self.headers.clone() {
85            Some(map.try_into()?)
86        } else {
87            None
88        };
89
90        // Start with polling transport
91        let mut transport = ClientPollingTransport::new(self.url.clone(), headers)?;
92
93        self.handshake_with_transport(&mut transport).await
94    }
95
96    /// Build websocket if allowed, if not fall back to polling
97    pub async fn build(mut self) -> Result<Socket> {
98        self.handshake().await?;
99
100        if self.websocket_upgrade()? {
101            self.build_websocket_with_upgrade().await
102        } else {
103            self.build_polling().await
104        }
105    }
106
107    /// Build websocket if allowed, if not allowed or errored fall back to polling.
108    /// WARNING: websocket errors suppressed, no indication of websocket success or failure.
109    pub async fn build_with_fallback(self) -> Result<Socket> {
110        let result = self.clone().build().await;
111        if result.is_err() {
112            self.build_polling().await
113        } else {
114            result
115        }
116    }
117
118    /// Checks the handshake to see if websocket upgrades are allowed
119    fn websocket_upgrade(&mut self) -> Result<bool> {
120        if self.handshake.is_none() {
121            return Ok(false);
122        }
123
124        Ok(self
125            .handshake
126            .as_ref()
127            .unwrap()
128            .upgrades
129            .iter()
130            .any(|upgrade| upgrade.to_lowercase() == *"websocket"))
131    }
132
133    /// Build socket with a polling transport then upgrade to websocket transport
134    pub async fn build_websocket_with_upgrade(mut self) -> Result<Socket> {
135        trace!("build_websocket_with_upgrade");
136        self.handshake().await?;
137
138        if self.websocket_upgrade()? {
139            self.build_websocket().await
140        } else {
141            Err(Error::IllegalWebsocketUpgrade())
142        }
143    }
144
145    /// Build socket with only a websocket transport
146    pub async fn build_websocket(mut self) -> Result<Socket> {
147        let headers = if let Some(map) = self.headers.clone() {
148            Some(map.try_into()?)
149        } else {
150            None
151        };
152
153        let (sender, receiver) = WebsocketTransport::connect(self.url.clone(), headers).await?;
154        let mut transport = WebsocketTransport::new(sender, receiver);
155
156        if self.handshake.is_some() {
157            transport.upgrade().await?;
158        } else {
159            self.handshake_with_transport(&mut transport).await?;
160        }
161
162        trace!("build_websocket success");
163
164        // NOTE: Although self.url contains the sid, it does not propagate to the transport
165        // SAFETY: handshake function called previously.
166        Ok(Socket::new(
167            TransportType::Websocket(transport),
168            self.handshake.unwrap(),
169            None,
170            self.should_pong,
171            false,
172        ))
173    }
174
175    pub async fn build_polling(mut self) -> Result<Socket> {
176        trace!("build_polling");
177        self.handshake().await?;
178
179        // Make a polling transport with new sid
180        // TODO: tls
181        let transport =
182            ClientPollingTransport::new(self.url, self.headers.map(|v| v.try_into().unwrap()))?;
183
184        // SAFETY: handshake function called previously.
185        Ok(Socket::new(
186            TransportType::ClientPolling(transport),
187            self.handshake.unwrap(),
188            None,
189            self.should_pong,
190            false,
191        ))
192    }
193
194    #[cfg(test)]
195    pub(crate) fn should_pong_for_test(mut self, should_pong: bool) -> Self {
196        self.should_pong = should_pong;
197        self
198    }
199}