// tf-rust-engineio 0.8.0
//
// An Engine.IO client implementation in Rust: a fork of rust_engineio with ACK support and
// reconnect enhancements.
use adler32::adler32;
use async_stream::try_stream;
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::{Stream, StreamExt};
use http::HeaderMap;
use native_tls::TlsConnector;
use reqwest::{Client, ClientBuilder, Response};
use std::fmt::Debug;
use std::time::SystemTime;
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use url::Url;

use crate::asynchronous::generator::StreamGenerator;
use crate::{asynchronous::transport::AsyncTransport, error::Result, Error};

/// Engine.IO HTTP long-polling transport built on the non-blocking `reqwest` client.
///
/// Outgoing packets are sent with HTTP POST (see [`AsyncTransport::emit`]). Incoming data is
/// obtained by repeatedly issuing HTTP GET requests and is exposed through the [`Stream`]
/// implementation.
#[derive(Clone)]
pub struct PollingTransport {
    /// HTTP client used both for the polling GETs and for the POSTs in `emit`.
    client: Client,
    /// Current base URL (carrying `transport=polling`); shared by all clones via the `Arc`.
    base_url: Arc<RwLock<Url>>,
    /// Produces the bytes read by the GET polling loop.
    generator: StreamGenerator<Bytes>,
}

impl PollingTransport {
    /// Creates a polling transport for `base_url`.
    ///
    /// `transport=polling` is appended to the query string of `base_url`. If `tls_config`
    /// and/or `opening_headers` are provided, they are installed on the `reqwest` client used
    /// for every request; otherwise a default client is used.
    ///
    /// The GET polling stream is created here from a copy of the resulting URL.
    /// NOTE(review): [`AsyncTransport::set_base_url`] only updates the stored base URL, so
    /// later polls keep using the URL captured here — confirm this is intended.
    ///
    /// # Panics
    ///
    /// Panics if the `reqwest` client cannot be built (for example when the TLS backend
    /// fails to initialize).
    pub fn new(
        base_url: Url,
        tls_config: Option<TlsConnector>,
        opening_headers: Option<HeaderMap>,
    ) -> Self {
        let client = match (tls_config, opening_headers) {
            (None, None) => Client::new(),
            (tls_config, opening_headers) => {
                let mut builder = ClientBuilder::new();
                if let Some(config) = tls_config {
                    builder = builder.use_preconfigured_tls(config);
                }
                if let Some(map) = opening_headers {
                    builder = builder.default_headers(map);
                }
                builder
                    .build()
                    .expect("failed to build the reqwest client for the polling transport")
            }
        };

        let mut url = base_url;
        url.query_pairs_mut().append_pair("transport", "polling");

        PollingTransport {
            client: client.clone(),
            base_url: Arc::new(RwLock::new(url.clone())),
            generator: StreamGenerator::new(Self::stream(url, client)),
        }
    }

    /// Returns `url` with a `t=<checksum>` query parameter appended.
    ///
    /// The value is an Adler-32 checksum of the debug-formatted current `SystemTime`, so
    /// successive requests get distinct URLs (presumably to defeat HTTP caching).
    fn address(mut url: Url) -> Result<Url> {
        let reader = format!("{:#?}", SystemTime::now());
        let hash = adler32(reader.as_bytes())
            .expect("reading from an in-memory byte slice cannot fail");
        url.query_pairs_mut().append_pair("t", &hash.to_string());
        Ok(url)
    }

    /// Issues a single GET to `url` (with a fresh `t` parameter).
    ///
    /// The returned stream yields exactly one item: the response on HTTP 200. Any other
    /// status becomes `Error::HttpErrorWithBody` when the body can be read, and
    /// `Error::IncompleteHttp` otherwise. Transport errors are forwarded as well.
    fn send_request(url: Url, client: Client) -> impl Stream<Item = Result<Response>> {
        try_stream! {
            let address = Self::address(url)?;
            let response = client.get(address).send().await?;

            let status = response.status().as_u16();
            if status == 200 {
                yield response;
            } else {
                // `text()` consumes the response, which is why this is an if/else rather
                // than an early exit followed by a fall-through `yield`.
                let err = match response.text().await {
                    Ok(body) => Error::HttpErrorWithBody { status, body },
                    Err(_) => Error::IncompleteHttp(status),
                };
                // In `try_stream!`, `?` emits the error as the next item and ends the body.
                Err::<(), Error>(err)?;
            }
        }
    }

    /// Builds the endless long-polling stream that backs [`PollingTransport`].
    ///
    /// Each loop iteration issues one GET via [`Self::send_request`] and yields the response
    /// body chunk by chunk as it arrives, so one response may produce several items. A
    /// request, status, or body error is forwarded as an `Err` item through `?`, which also
    /// ends the stream.
    fn stream(
        url: Url,
        client: Client,
    ) -> Pin<Box<dyn Stream<Item = Result<Bytes>> + 'static + Send>> {
        Box::pin(try_stream! {
            loop {
                for await elem in Self::send_request(url.clone(), client.clone()) {
                    for await bytes in elem?.bytes_stream() {
                        yield bytes?;
                    }
                }
            }
        })
    }
}

impl Stream for PollingTransport {
    type Item = Result<Bytes>;

    /// Forwards polling to the inner generator, which yields the bytes received by the GET
    /// polling loop.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `PollingTransport` is `Unpin`, so the pin can be safely unwrapped to `&mut Self`.
        let this = self.get_mut();
        this.generator.poll_next_unpin(cx)
    }
}

#[async_trait]
impl AsyncTransport for PollingTransport {
    /// Sends `data` to the server in a single HTTP POST to `self.address()`.
    ///
    /// When `is_binary_att` is set, the payload is base64-encoded (standard alphabet) and
    /// prefixed with `b`; otherwise it is posted unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if the address cannot be built or the request fails. A status other
    /// than 200 yields `Error::HttpErrorWithBody` when the response body can be read, and
    /// `Error::IncompleteHttp` otherwise.
    async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
        let data_to_send = if is_binary_att {
            // The binary attachment is sent as a `b`-prefixed base64 string. Encode first so
            // the buffer is sized exactly: base64 output is ~4/3 of the input, so sizing from
            // `data.len()` would always under-allocate and force a reallocation.
            let encoded_data = general_purpose::STANDARD.encode(data);
            let mut packet_bytes = BytesMut::with_capacity(encoded_data.len() + 1);
            packet_bytes.put_u8(b'b');
            packet_bytes.put_slice(encoded_data.as_bytes());

            packet_bytes.freeze()
        } else {
            data
        };

        let response = self
            .client
            .post(self.address().await?)
            .body(data_to_send)
            .send()
            .await?;

        let status = response.status().as_u16();
        if status != 200 {
            return Err(match response.text().await {
                Ok(body) => Error::HttpErrorWithBody { status, body },
                Err(_) => Error::IncompleteHttp(status),
            });
        }

        Ok(())
    }

    /// Returns a copy of the current base URL.
    async fn base_url(&self) -> Result<Url> {
        Ok(self.base_url.read().await.clone())
    }

    /// Replaces the stored base URL with `base_url`.
    ///
    /// `transport=polling` is appended unless that exact key/value pair is already present;
    /// any other existing `transport` value is left in place. Only the stored URL changes:
    /// the GET polling stream built in `PollingTransport::new` keeps the URL it captured.
    async fn set_base_url(&self, base_url: Url) -> Result<()> {
        let mut url = base_url;
        if !url
            .query_pairs()
            .any(|(k, v)| k == "transport" && v == "polling")
        {
            url.query_pairs_mut().append_pair("transport", "polling");
        }
        *self.base_url.write().await = url;
        Ok(())
    }
}

impl Debug for PollingTransport {
    /// Formats the client and base URL; the `generator` field is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PollingTransport");
        builder.field("client", &self.client);
        builder.field("base_url", &self.base_url);
        builder.finish()
    }
}

#[cfg(test)]
mod test {
    use crate::asynchronous::transport::AsyncTransport;

    use super::*;
    use bytes::Bytes;
    use futures_util::StreamExt;
    use std::str::FromStr;

    /// Body served by the HTTP error mock in the error-path tests.
    const MISMATCH_BODY: &str = r#"{"code":4008,"message":"Protocol version mismatch"}"#;

    /// Fails the test unless `err` is `Error::HttpErrorWithBody` with status 400 and
    /// `expected_body` as its body.
    fn assert_http_400_with_body(err: Error, expected_body: &str) {
        match err {
            Error::HttpErrorWithBody { status, body } => {
                assert_eq!(status, 400);
                assert_eq!(body, expected_body);
            }
            other => panic!("expected HttpErrorWithBody, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn polling_transport_emit_returns_http_error_with_body() {
        let mock_url = crate::test::spawn_http_error_mock(400, MISMATCH_BODY);
        let transport = PollingTransport::new(mock_url, None, None);

        let outcome = transport.emit(Bytes::from_static(b"hello"), false).await;
        assert_http_400_with_body(
            outcome.expect_err("emit should fail when server returns 400"),
            MISMATCH_BODY,
        );
    }

    #[tokio::test]
    async fn polling_transport_get_returns_http_error_with_body() {
        let mock_url = crate::test::spawn_http_error_mock(400, MISMATCH_BODY);
        let mut transport = PollingTransport::new(mock_url, None, None);

        let first_item = transport
            .next()
            .await
            .expect("stream should yield an item");
        assert_http_400_with_body(
            first_item.expect_err("GET should fail when server returns 400"),
            MISMATCH_BODY,
        );
    }

    #[tokio::test]
    async fn polling_transport_base_url() -> Result<()> {
        let server_url = crate::test::engine_io_server()?.to_string();
        let transport = PollingTransport::new(Url::from_str(&server_url).unwrap(), None, None);

        let current = transport.base_url().await?.to_string();
        assert_eq!(current, format!("{server_url}?transport=polling"));

        transport
            .set_base_url(Url::parse("https://127.0.0.1")?)
            .await?;
        let current = transport.base_url().await?.to_string();
        assert_eq!(current, "https://127.0.0.1/?transport=polling");
        assert_ne!(current, server_url);

        transport
            .set_base_url(Url::parse("http://127.0.0.1/?transport=polling")?)
            .await?;
        let current = transport.base_url().await?.to_string();
        assert_eq!(current, "http://127.0.0.1/?transport=polling");
        assert_ne!(current, server_url);

        Ok(())
    }
}