use adler32::adler32;
use async_stream::try_stream;
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine as _};
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::{Stream, StreamExt};
use http::HeaderMap;
use native_tls::TlsConnector;
use reqwest::{Client, ClientBuilder, Response};
use std::fmt::Debug;
use std::time::SystemTime;
use std::{pin::Pin, sync::Arc};
use tokio::sync::RwLock;
use url::Url;
use crate::asynchronous::generator::StreamGenerator;
use crate::{asynchronous::transport::AsyncTransport, error::Result, Error};
/// Transport that exchanges data with the server over plain HTTP polling.
///
/// Received data is obtained by polling this type as a `Stream` (which
/// forwards to `generator`); outgoing data is sent with HTTP POST requests
/// through the `AsyncTransport::emit` implementation.
#[derive(Clone)]
pub struct PollingTransport {
/// HTTP client used for the polling GET requests and for the POSTs in `emit`.
client: Client,
/// Base URL of the server, kept behind an async lock so that `set_base_url`
/// can replace it through a shared reference.
base_url: Arc<RwLock<Url>>,
/// Stream of received byte chunks, created in `new` from `Self::stream`.
generator: StreamGenerator<Bytes>,
}
impl PollingTransport {
pub fn new(
base_url: Url,
tls_config: Option<TlsConnector>,
opening_headers: Option<HeaderMap>,
) -> Self {
let client = match (tls_config, opening_headers) {
(Some(config), Some(map)) => ClientBuilder::new()
.use_preconfigured_tls(config)
.default_headers(map)
.build()
.unwrap(),
(Some(config), None) => ClientBuilder::new()
.use_preconfigured_tls(config)
.build()
.unwrap(),
(None, Some(map)) => ClientBuilder::new().default_headers(map).build().unwrap(),
(None, None) => Client::new(),
};
let mut url = base_url;
url.query_pairs_mut().append_pair("transport", "polling");
PollingTransport {
client: client.clone(),
base_url: Arc::new(RwLock::new(url.clone())),
generator: StreamGenerator::new(Self::stream(url, client)),
}
}
fn address(mut url: Url) -> Result<Url> {
let reader = format!("{:#?}", SystemTime::now());
let hash = adler32(reader.as_bytes()).unwrap();
url.query_pairs_mut().append_pair("t", &hash.to_string());
Ok(url)
}
fn send_request(url: Url, client: Client) -> impl Stream<Item = Result<Response>> {
try_stream! {
let address = Self::address(url);
let response = client
.get(address?)
.send().await?;
let status = response.status().as_u16();
if status != 200 {
let err = match response.text().await {
Ok(body) => Error::HttpErrorWithBody { status, body },
Err(_) => Error::IncompleteHttp(status),
};
Err(err)?;
unreachable!();
}
yield response
}
}
fn stream(
url: Url,
client: Client,
) -> Pin<Box<dyn Stream<Item = Result<Bytes>> + 'static + Send>> {
Box::pin(try_stream! {
loop {
for await elem in Self::send_request(url.clone(), client.clone()) {
for await bytes in elem?.bytes_stream() {
yield bytes?;
}
}
}
})
}
}
impl Stream for PollingTransport {
    type Item = Result<Bytes>;

    /// Polls the inner generator for the next chunk of received bytes.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Re-pin the generator field and delegate to its own `poll_next`.
        let generator = Pin::new(&mut self.generator);
        generator.poll_next(cx)
    }
}
#[async_trait]
impl AsyncTransport for PollingTransport {
    /// Sends `data` to the server with an HTTP POST.
    ///
    /// When `is_binary_att` is set, the payload is base64-encoded and
    /// prefixed with the byte `b`; otherwise it is sent unchanged.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be sent or if the server answers with a
    /// status other than 200. In the latter case the error is
    /// `Error::HttpErrorWithBody` when the response body could be read and
    /// `Error::IncompleteHttp` otherwise.
    async fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
        let data_to_send = if is_binary_att {
            // Encode first: base64 grows the payload, so the buffer has to be
            // sized from the encoded length to avoid a reallocation.
            let encoded = general_purpose::STANDARD.encode(data);
            let mut packet = BytesMut::with_capacity(encoded.len() + 1);
            packet.put_u8(b'b');
            packet.put(encoded.as_bytes());
            packet.freeze()
        } else {
            data
        };

        let response = self
            .client
            .post(self.address().await?)
            .body(data_to_send)
            .send()
            .await?;

        let status = response.status().as_u16();
        if status != 200 {
            return Err(match response.text().await {
                Ok(body) => Error::HttpErrorWithBody { status, body },
                Err(_) => Error::IncompleteHttp(status),
            });
        }

        Ok(())
    }

    /// Returns a copy of the current base URL.
    async fn base_url(&self) -> Result<Url> {
        Ok(self.base_url.read().await.clone())
    }

    /// Replaces the base URL, appending `transport=polling` to the query
    /// unless that exact pair is already present.
    async fn set_base_url(&self, base_url: Url) -> Result<()> {
        let mut url = base_url;
        let already_polling = url
            .query_pairs()
            .any(|(k, v)| k == "transport" && v == "polling");
        if !already_polling {
            url.query_pairs_mut().append_pair("transport", "polling");
        }

        *self.base_url.write().await = url;
        Ok(())
    }
}
impl Debug for PollingTransport {
    /// Formats the transport showing its `client` and `base_url` fields; the
    /// `generator` field is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PollingTransport");
        out.field("client", &self.client);
        out.field("base_url", &self.base_url);
        out.finish()
    }
}
#[cfg(test)]
mod test {
    use crate::asynchronous::transport::AsyncTransport;

    use super::*;
    use bytes::Bytes;
    use futures_util::StreamExt;
    use std::str::FromStr;

    /// Checks that `err` is an `Error::HttpErrorWithBody` carrying exactly the
    /// given status code and response body; panics otherwise.
    fn assert_http_error_with_body(err: Error, expected_status: u16, expected_body: &str) {
        match err {
            Error::HttpErrorWithBody { status, body } => {
                assert_eq!(status, expected_status);
                assert_eq!(body, expected_body);
            }
            other => panic!("expected HttpErrorWithBody, got: {other:?}"),
        }
    }

    /// A POST answered with a 400 must surface the status and the body.
    #[tokio::test]
    async fn polling_transport_emit_returns_http_error_with_body() {
        let body = r#"{"code":4008,"message":"Protocol version mismatch"}"#;
        let url = crate::test::spawn_http_error_mock(400, body);
        let transport = PollingTransport::new(url, None, None);

        let outcome = transport.emit(Bytes::from_static(b"hello"), false).await;
        let err = outcome.expect_err("emit should fail when server returns 400");

        assert_http_error_with_body(err, 400, body);
    }

    /// A polling GET answered with a 400 must surface the status and the body.
    #[tokio::test]
    async fn polling_transport_get_returns_http_error_with_body() {
        let body = r#"{"code":4008,"message":"Protocol version mismatch"}"#;
        let url = crate::test::spawn_http_error_mock(400, body);
        let mut transport = PollingTransport::new(url, None, None);

        let first_item = transport
            .next()
            .await
            .expect("stream should yield an item");
        let err = first_item.expect_err("GET should fail when server returns 400");

        assert_http_error_with_body(err, 400, body);
    }

    /// `base_url` reflects the URL given to `new` and every later
    /// `set_base_url`, always carrying a single `transport=polling` pair.
    #[tokio::test]
    async fn polling_transport_base_url() -> Result<()> {
        let url = crate::test::engine_io_server()?.to_string();
        let transport = PollingTransport::new(Url::from_str(url.as_str()).unwrap(), None, None);

        let initial = transport.base_url().await?.to_string();
        assert_eq!(initial, format!("{url}?transport=polling"));

        // A URL without the parameter gets it appended.
        let without_param = Url::parse("https://127.0.0.1")?;
        transport.set_base_url(without_param).await?;
        let current = transport.base_url().await?.to_string();
        assert_eq!(current, "https://127.0.0.1/?transport=polling");
        assert_ne!(current, url);

        // A URL that already has the parameter is stored unchanged.
        let with_param = Url::parse("http://127.0.0.1/?transport=polling")?;
        transport.set_base_url(with_param).await?;
        let current = transport.base_url().await?.to_string();
        assert_eq!(current, "http://127.0.0.1/?transport=polling");
        assert_ne!(current, url);

        Ok(())
    }
}