use std::collections::HashMap;
use http_body_util::BodyExt;
use tokio::sync::mpsc;
use crate::error::{ClientError, ClientResult};
use crate::streaming::EventStream;
use super::RestTransport;
impl RestTransport {
pub(super) async fn execute_streaming_request(
&self,
method: &str,
params: serde_json::Value,
extra_headers: &HashMap<String, String>,
) -> ClientResult<EventStream> {
trace_info!(method, base_url = %self.inner.base_url, "opening REST SSE stream");
let req = self.build_request(method, ¶ms, extra_headers, true)?;
let resp = tokio::time::timeout(
self.inner.stream_connect_timeout,
self.inner.client.request(req),
)
.await
.map_err(|_| {
trace_error!(method, "stream connect timed out");
ClientError::Timeout("stream connect timed out".into())
})?
.map_err(|e| {
trace_error!(method, error = %e, "HTTP client error");
ClientError::HttpClient(e.to_string())
})?;
let status = resp.status();
if !status.is_success() {
let body_bytes =
tokio::time::timeout(self.inner.stream_connect_timeout, resp.collect())
.await
.map_err(|_| ClientError::Timeout("error body read timed out".into()))?
.map_err(ClientError::Http)?
.to_bytes();
let body_str = String::from_utf8_lossy(&body_bytes);
return Err(ClientError::UnexpectedStatus {
status: status.as_u16(),
body: super::super::truncate_body(&body_str),
});
}
let actual_status = status.as_u16();
let (tx, rx) = mpsc::channel::<crate::streaming::event_stream::BodyChunk>(64);
let body = resp.into_body();
let task_handle = tokio::spawn(async move {
body_reader_task(body, tx).await;
});
Ok(EventStream::with_status(
rx,
task_handle.abort_handle(),
actual_status,
))
}
}
async fn body_reader_task(
mut body: hyper::body::Incoming,
tx: mpsc::Sender<crate::streaming::event_stream::BodyChunk>,
) {
loop {
match body.frame().await {
None => break,
Some(Err(e)) => {
let _ = tx.send(Err(ClientError::Http(e))).await;
break;
}
Some(Ok(f)) => {
if let Ok(data) = f.into_data() {
if tx.send(Ok(data)).await.is_err() {
break;
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use http_body_util::Full;
use hyper::body::Bytes;
use super::super::*;
#[tokio::test]
async fn execute_streaming_request_non_success_returns_error() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = hyper_util::rt::TokioIo::new(stream);
tokio::spawn(async move {
let service = hyper::service::service_fn(|_req| async {
Ok::<_, hyper::Error>(
hyper::Response::builder()
.status(500)
.header("content-type", "text/plain")
.body(Full::new(Bytes::from("Internal Server Error")))
.unwrap(),
)
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = RestTransport::new(&url).unwrap();
let result = transport
.execute_streaming_request(
"SendStreamingMessage",
serde_json::json!({}),
&HashMap::new(),
)
.await;
match result {
Err(ClientError::UnexpectedStatus { status, .. }) => {
assert_eq!(status, 500);
}
other => panic!("expected UnexpectedStatus, got {other:?}"),
}
}
#[tokio::test]
async fn execute_streaming_request_success_returns_event_stream() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = hyper_util::rt::TokioIo::new(stream);
tokio::spawn(async move {
let service = hyper::service::service_fn(|_req| async {
let sse_body = "data: {\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"status\":\"ok\"}}\n\n";
Ok::<_, hyper::Error>(
hyper::Response::builder()
.status(200)
.header("content-type", "text/event-stream")
.body(Full::new(Bytes::from(sse_body)))
.unwrap(),
)
});
let _ = hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(io, service)
.await;
});
}
});
let url = format!("http://127.0.0.1:{}", addr.port());
let transport = RestTransport::new(&url).unwrap();
let mut stream = transport
.execute_streaming_request(
"SendStreamingMessage",
serde_json::json!({}),
&HashMap::new(),
)
.await
.unwrap();
let event = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
.await
.expect("timed out waiting for event");
assert!(
event.is_some(),
"expected at least one event from the stream"
);
}
}