use std::io::{BufRead, Read};
use std::sync::Arc;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use rustls_pki_types::ServerName;
use tokio::select;
use tokio::{
net::TcpStream,
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
};
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::TlsConnector;
/// Upper bound, in milliseconds, on how long the client waits for the server's
/// response to the initial request.
const TIMEOUT_MS: u64 = 10_000;
/// Errors reported by the SSE client.
///
/// Every variant wraps the underlying error as its source through `#[from]`,
/// so `?` converts the wrapped error types automatically.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An I/O level failure. In this file it is also used for the request
/// timeout, an invalid TLS server name and a non-success HTTP status.
#[error("Io error")]
IoError(#[from] std::io::Error),
/// A failure reported by hyper (handshake, sending the request, or reading
/// a body frame).
#[error("Hyper error")]
HyperError(#[from] hyper::Error),
/// An `http` crate error.
// NOTE(review): not constructed by the visible code; presumably used by
// callers that build the `Request` - confirm.
#[error("Http error")]
HttpError(#[from] hyper::http::Error),
/// The `data` payload of an event was not valid JSON.
#[error("Json error")]
JsonError(#[from] serde_json::Error),
}
/// Shorthand for a `Result` whose error type is this module's [`Error`].
type Result<T> = std::result::Result<T, Error>;
/// Handle to a background task that reads a server-sent-event stream.
///
/// Events and errors produced by the task are received through `rx`; the
/// `Stream` implementation of this type polls that receiver. Dropping the
/// handle fires `shutdown`.
pub(crate) struct SseClient {
// Handle of the task spawned in `spawn`; it is stored but never read here.
_join_handle: tokio::task::JoinHandle<()>,
// One-shot sender used to ask the task to stop. Wrapped in `Option` so that
// `Drop` can move it out of `&mut self`.
shutdown: Option<tokio::sync::oneshot::Sender<()>>,
// Receiving half of the channel the background task writes to.
rx: UnboundedReceiver<Result<SseValue>>,
}
/// One decoded server-sent event.
#[derive(Debug)]
pub(crate) struct SseValue {
/// Text of the message's `event:` line; empty when the message had none.
pub(crate) event: String,
/// JSON value parsed from the message's `data:` payload.
pub(crate) value: serde_json::Value,
}
/// Decodes one complete SSE message and forwards the resulting event to `tx`.
///
/// `message` must yield the bytes of a single message, i.e. everything between
/// two blank-line delimiters. Each line is split at the first `:` into a field
/// name and a value (one leading space of the value is dropped):
///
/// * `data`  - payload line; several `data` lines are joined with `\n`,
/// * `event` - event name; the last one wins,
/// * anything else (comments starting with `:`, unknown fields, blank lines)
///   is ignored.
///
/// A message without any `data` line produces no event.
///
/// Returns `Ok(true)` when the caller should keep reading and `Ok(false)` once
/// the receiving half of `tx` has been dropped.
///
/// # Errors
///
/// Returns [`Error::IoError`] if a line is not valid UTF-8 and
/// [`Error::JsonError`] if the joined `data` payload is not valid JSON.
fn forward_message(
    message: impl BufRead,
    tx: &UnboundedSender<Result<SseValue>>,
) -> Result<bool> {
    let mut event = String::new();
    let mut data: Option<String> = None;

    for line in message.lines() {
        let line = line?;
        // A comment line (`: keep-alive`) splits into an empty field name and
        // therefore falls through to the ignored arm below.
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line.as_str(), ""),
        };
        match field {
            "data" => {
                data = Some(match data.take() {
                    Some(mut joined) => {
                        joined.push('\n');
                        joined.push_str(value);
                        joined
                    }
                    None => value.to_owned(),
                });
            }
            "event" => {
                event.clear();
                event.push_str(value);
            }
            _ => {}
        }
    }

    let Some(data) = data else {
        return Ok(true);
    };
    let value = serde_json::from_str(&data)?;
    if tx.send(Ok(SseValue { event, value })).is_err() {
        tracing::error!("stream disconnected prematurely");
        return Ok(false);
    }
    Ok(true)
}

/// Reads the body of `res` as a server-sent-event stream and forwards every
/// decoded event to `tx`.
///
/// Messages are delimited by a blank line (`\n\n`). Bytes of a message that is
/// still incomplete when a chunk ends are kept in a carry-over buffer and
/// prepended to the next chunk, so messages and delimiters may be split across
/// chunks at any byte. A message that is still incomplete when the body ends
/// is discarded.
///
/// Returns `Ok(())` when the body ends or when the receiver of `tx` has been
/// dropped.
///
/// # Errors
///
/// Returns [`Error::HyperError`] if reading a body frame fails, and the errors
/// of message decoding (invalid UTF-8 as [`Error::IoError`], invalid JSON as
/// [`Error::JsonError`]).
async fn receive_events(
    mut res: Response<Incoming>,
    tx: UnboundedSender<Result<SseValue>>,
) -> Result<()> {
    // Bytes of the message currently being assembled that arrived in earlier
    // chunks. It never contains a complete delimiter.
    let mut carry: Vec<u8> = Vec::new();

    while let Some(next) = res.frame().await {
        let frame = next?;
        let Some(bytes) = frame.data_ref() else {
            // Frames without data (e.g. trailers) carry no event bytes.
            continue;
        };
        let mut rest: &[u8] = &**bytes;
        tracing::debug!("Received chunk: `{}`", String::from_utf8_lossy(rest));

        loop {
            // Locate the next delimiter as a triple:
            // (bytes of `carry` in the message,
            //  bytes of `rest` in the message,
            //  delimiter bytes to skip in `rest`).
            let split = if carry.last() == Some(&b'\n') && rest.first() == Some(&b'\n') {
                // The delimiter straddles the chunk boundary: its first byte
                // is the end of `carry`, its second the start of `rest`.
                Some((carry.len() - 1, 0, 1))
            } else {
                rest.windows(2)
                    .position(|pair| pair == b"\n\n")
                    .map(|at| (carry.len(), at, 2))
            };
            let Some((carried, taken, delimiter)) = split else {
                break;
            };

            // Read the message straight out of both buffers without copying.
            let prefix: &[u8] = &carry[..carried];
            let keep_going = forward_message(prefix.chain(&rest[..taken]), &tx)?;
            // Clear unconditionally so a data-less message (e.g. a keep-alive
            // comment) cannot leak into the next one.
            carry.clear();
            rest = &rest[taken + delimiter..];
            if !keep_going {
                return Ok(());
            }
        }

        carry.extend_from_slice(rest);
    }
    Ok(())
}
/// Opens a TLS + HTTP/2 connection to the host named in `request`, sends the
/// request and streams the response body as server-sent events into `tx`.
///
/// The port defaults to 443 when the URI does not specify one. Waiting for the
/// response to the request is limited to `TIMEOUT_MS` milliseconds. Once the
/// stream is open the function runs until the body ends, the receiver of `tx`
/// is dropped, or `shutdown_signal` resolves.
///
/// # Errors
///
/// * [`Error::IoError`] - the URI has no host, the host is not a valid TLS
///   server name, connecting or the TLS handshake failed, the request timed
///   out, or the server answered with a non-success status (the response body
///   is included in the message).
/// * [`Error::HyperError`] - the HTTP/2 handshake, sending the request or
///   reading the event stream failed.
/// * [`Error::JsonError`] - an event carried a payload that is not valid JSON.
async fn run_client(
    request: Request<String>,
    tx: UnboundedSender<Result<SseValue>>,
    shutdown_signal: tokio::sync::oneshot::Receiver<()>,
) -> Result<()> {
    let url = request.uri();
    // Report a missing host as an error instead of panicking inside the
    // spawned task, where the caller would never see the reason.
    let host = url.host().ok_or_else(|| {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, "request url has no host")
    })?;
    let port = url.port_u16().unwrap_or(443);

    let mut root_cert_store = RootCertStore::empty();
    root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let mut config = ClientConfig::builder()
        .with_root_certificates(root_cert_store)
        .with_no_client_auth();
    // Only HTTP/2 is offered via ALPN; the handshake below is HTTP/2 as well.
    config.alpn_protocols = vec![b"h2".to_vec()];
    let connector = TlsConnector::from(Arc::new(config));
    let tls_domain = ServerName::try_from(host.to_string())
        .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid dnsname"))?;

    let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
    let stream = connector.connect(tls_domain, stream).await?;
    let executor = hyper_util::rt::tokio::TokioExecutor::new();
    let io = TokioIo::new(stream);
    let (mut sender, connection) = hyper::client::conn::http2::handshake(executor, io).await?;

    // The connection future drives the HTTP/2 session and has to be polled
    // for requests to make progress, so it gets its own task.
    tokio::task::spawn(async move {
        if let Err(e) = connection.await {
            tracing::error!("connection error: {}", e);
        }
        tracing::debug!("connection closed");
    });

    let work = sender.send_request(request);
    let mut res =
        match tokio::time::timeout(std::time::Duration::from_millis(TIMEOUT_MS), work).await {
            Ok(result) => result?,
            Err(_) => {
                return Err(tokio::io::Error::new(tokio::io::ErrorKind::TimedOut, "Timeout").into())
            }
        };

    let status = res.status();
    if !status.is_success() {
        // Collect as much of the error body as can be read; a frame error
        // simply ends the collection.
        let mut bytes = vec![];
        while let Some(Ok(frame)) = res.frame().await {
            if let Some(chunk) = frame.data_ref() {
                bytes.extend_from_slice(chunk);
            }
        }
        let body = String::from_utf8_lossy(&bytes);
        return Err(tokio::io::Error::new(
            tokio::io::ErrorKind::Other,
            format!("request failed with status: {status} - `{body}`"),
        )
        .into());
    }
    tracing::debug!("sse opened successfully");

    // Propagate the outcome of the event loop so that mid-stream failures
    // reach the caller instead of looking like a clean end of stream.
    select! {
        outcome = receive_events(res, tx) => outcome,
        _ = shutdown_signal => Ok(()),
    }
}
impl SseClient {
    /// Starts a background task that performs `request` and returns a handle
    /// through which the decoded events can be received.
    ///
    /// If the task fails, the error is delivered as an item on the same
    /// channel as the events.
    pub(crate) fn spawn(request: Request<String>) -> Self {
        let (shutdown, shutdown_signal) = tokio::sync::oneshot::channel::<()>();
        let (events_tx, events_rx) = unbounded_channel();

        let task = async move {
            // The task keeps `events_tx` for itself so it can still report a
            // failure after the clone has been handed to `run_client`.
            let outcome = run_client(request, events_tx.clone(), shutdown_signal).await;
            if let Err(failure) = outcome {
                let _ = events_tx.send(Err(failure));
            }
        };

        Self {
            _join_handle: tokio::spawn(task),
            shutdown: Some(shutdown),
            rx: events_rx,
        }
    }
}
impl futures::Stream for SseClient {
    type Item = Result<SseValue>;

    /// Yields the next item sent by the background task by polling the
    /// internal channel receiver.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let client = self.get_mut();
        client.rx.poll_recv(cx)
    }
}
impl Drop for SseClient {
    /// Signals the background task to stop when the handle goes away.
    fn drop(&mut self) {
        let Some(signal) = self.shutdown.take() else {
            return;
        };
        // Nothing to notify if the task has already dropped its receiver.
        if signal.is_closed() {
            return;
        }
        let _ = signal.send(());
    }
}