use bytes::Bytes;
use futures_lite::Stream;
use futures_lite::StreamExt;
use futures_lite::stream;
use serde::de::DeserializeOwned;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use crate::body::Body;
use crate::error::{Error, ErrorKind, Result};
use crate::header::HeaderMap;
use crate::metrics::Metrics;
use crate::url::Url;
/// HTTP protocol version recorded for a [`Response`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Version {
/// HTTP/1.1.
Http11,
/// HTTP/2.
Http2,
/// HTTP/3.
Http3,
}
/// An HTTP status code, kept as its raw numeric value.
///
/// [`StatusCode::new`] accepts any `u16`; the associated constants name the
/// codes this crate refers to most often.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct StatusCode(u16);

impl StatusCode {
    // 1xx
    pub const CONTINUE: Self = Self(100);
    pub const SWITCHING_PROTOCOLS: Self = Self(101);
    // 2xx
    pub const OK: Self = Self(200);
    pub const CREATED: Self = Self(201);
    pub const ACCEPTED: Self = Self(202);
    pub const NO_CONTENT: Self = Self(204);
    pub const PARTIAL_CONTENT: Self = Self(206);
    // 3xx
    pub const MOVED_PERMANENTLY: Self = Self(301);
    pub const FOUND: Self = Self(302);
    pub const SEE_OTHER: Self = Self(303);
    pub const NOT_MODIFIED: Self = Self(304);
    pub const TEMPORARY_REDIRECT: Self = Self(307);
    pub const PERMANENT_REDIRECT: Self = Self(308);
    // 4xx
    pub const BAD_REQUEST: Self = Self(400);
    pub const UNAUTHORIZED: Self = Self(401);
    pub const FORBIDDEN: Self = Self(403);
    pub const NOT_FOUND: Self = Self(404);
    pub const METHOD_NOT_ALLOWED: Self = Self(405);
    pub const CONFLICT: Self = Self(409);
    pub const GONE: Self = Self(410);
    pub const UNPROCESSABLE_ENTITY: Self = Self(422);
    pub const TOO_MANY_REQUESTS: Self = Self(429);
    // 5xx
    pub const INTERNAL_SERVER_ERROR: Self = Self(500);
    pub const NOT_IMPLEMENTED: Self = Self(501);
    pub const BAD_GATEWAY: Self = Self(502);
    pub const SERVICE_UNAVAILABLE: Self = Self(503);
    pub const GATEWAY_TIMEOUT: Self = Self(504);

    /// Wraps `value` without checking that it is a registered status code.
    pub fn new(value: u16) -> Self {
        Self(value)
    }

    /// Returns the numeric status code.
    pub fn as_u16(self) -> u16 {
        let Self(code) = self;
        code
    }

    /// Hundreds digit of the code, i.e. its class (`2` for every `2xx`).
    fn class(self) -> u16 {
        self.0 / 100
    }

    /// `true` for `1xx` codes.
    pub fn is_informational(self) -> bool {
        self.class() == 1
    }

    /// `true` for `2xx` codes.
    pub fn is_success(self) -> bool {
        self.class() == 2
    }

    /// `true` only for 301, 302, 303, 307 and 308.
    ///
    /// Other `3xx` codes, such as 300 and 304, return `false`.
    pub fn is_redirect(self) -> bool {
        matches!(self.0, 301..=303 | 307 | 308)
    }

    /// `true` for `4xx` codes.
    pub fn is_client_error(self) -> bool {
        self.class() == 4
    }

    /// `true` for `5xx` codes.
    pub fn is_server_error(self) -> bool {
        self.class() == 5
    }
}
/// An HTTP response: status, protocol version, URL, headers, body, optional
/// trailers and collected metrics.
#[derive(Debug)]
pub struct Response {
/// Status code of the response.
status: StatusCode,
/// Protocol version recorded for the response.
version: Version,
/// URL passed in when the response was constructed.
url: Url,
/// Response headers (mutable through `headers_mut`).
headers: HeaderMap,
/// Trailing headers: either known at construction time (`Ready`) or read
/// later from a shared cell (`Deferred`).
trailers: TrailerState,
/// Response body; taken, buffered or streamed, by the consuming methods.
body: Body,
/// Metrics; `Metrics::default()` at construction, replaced via `with_metrics`.
metrics: Metrics,
}
/// Boxed, `Send` stream of body chunks; each item is a chunk or an error.
pub type BytesStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + 'static>>;
/// Boxed, `Send` stream of parsed server-sent events, as returned by `Response::sse`.
pub type SseStream = Pin<Box<dyn Stream<Item = Result<SseEvent>> + Send + 'static>>;
/// A single server-sent event produced by `Response::sse`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct SseEvent {
/// Event type from the `event` field; `"message"` when none, or an empty one, was given.
pub event: String,
/// Payload: the values of the event's `data` fields joined with `\n`.
pub data: String,
/// The last `id` value received on the stream. It carries over to later events
/// until another `id` field replaces it. `None` if no id has been received or
/// the latest one was empty.
pub id: Option<String>,
/// Delay parsed (as milliseconds) from a `retry` field received since the
/// previously dispatched event, if any.
pub retry: Option<std::time::Duration>,
}
impl Response {
pub fn new(
status: StatusCode,
version: Version,
url: Url,
headers: HeaderMap,
trailers: Option<HeaderMap>,
body: Body,
) -> Self {
Self {
status,
version,
url,
headers,
trailers: TrailerState::Ready(trailers),
body,
metrics: Metrics::default(),
}
}
pub(crate) fn new_with_trailer_state(
status: StatusCode,
version: Version,
url: Url,
headers: HeaderMap,
trailers: TrailerState,
body: Body,
) -> Self {
Self {
status,
version,
url,
headers,
trailers,
body,
metrics: Metrics::default(),
}
}
pub fn status(&self) -> StatusCode {
self.status
}
pub fn version(&self) -> Version {
self.version
}
pub fn url(&self) -> &Url {
&self.url
}
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
pub fn metrics(&self) -> &Metrics {
&self.metrics
}
pub fn cookies(&self) -> Vec<&str> {
self.headers.get_all("set-cookie")
}
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}
pub fn trailers(&self) -> Option<&HeaderMap> {
match &self.trailers {
TrailerState::Ready(trailers) => trailers.as_ref(),
TrailerState::Deferred(trailers) => trailers.get(),
}
}
pub fn error_for_status(self) -> Result<Self> {
if self.status.is_client_error() || self.status.is_server_error() {
return Err(Error::new(
ErrorKind::Transport,
format!("unexpected response status: {}", self.status.as_u16()),
));
}
Ok(self)
}
pub async fn bytes(self) -> Result<Bytes> {
Ok(self.bytes_and_trailers().await?.0)
}
pub async fn bytes_and_trailers(mut self) -> Result<(Bytes, Option<HeaderMap>)> {
let bytes = if let Ok(bytes) = self.body.take_bytes() {
bytes
} else {
let mut stream = self.body.take_stream()?;
let mut data = Vec::new();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
data.extend_from_slice(&chunk);
}
Bytes::from(data)
};
Ok((bytes, self.take_trailers()))
}
pub async fn text(self) -> Result<String> {
let bytes = self.bytes().await?;
String::from_utf8(bytes.to_vec()).map_err(|err| {
Error::with_source(ErrorKind::Decode, "response body is not valid utf-8", err)
})
}
pub async fn json<T: DeserializeOwned>(self) -> Result<T> {
let bytes = self.bytes().await?;
serde_json::from_slice(&bytes).map_err(|err| {
Error::with_source(ErrorKind::Decode, "failed to decode response json", err)
})
}
pub fn bytes_stream(self) -> BytesStream {
self.into_body_stream_and_trailer_state().0
}
pub(crate) fn into_body_stream_and_trailer_state(mut self) -> (BytesStream, TrailerState) {
let body: BytesStream = if let Ok(stream_body) = self.body.take_stream() {
Box::pin(stream_body)
} else {
match self.body.take_bytes() {
Ok(bytes) => Box::pin(stream::once(Ok(bytes))),
Err(err) => Box::pin(stream::once(Err(err))),
}
};
let trailers = std::mem::replace(&mut self.trailers, TrailerState::Ready(None));
(body, trailers)
}
pub fn sse(self) -> Result<SseStream> {
if let Some(content_type) = self.headers().get("content-type") {
let mime = content_type
.split(';')
.next()
.map(str::trim)
.unwrap_or(content_type);
if !mime.eq_ignore_ascii_case("text/event-stream") {
return Err(Error::new(
ErrorKind::Decode,
format!("response is not an SSE stream: {content_type}"),
));
}
}
let (mut body, _trailers) = self.into_body_stream_and_trailer_state();
let events = async_stream::try_stream! {
let mut parser = SseParser::default();
while let Some(chunk) = body.next().await {
let chunk = chunk?;
parser.push(&chunk);
while let Some(event) = parser.next_event()? {
yield event;
}
}
if let Some(event) = parser.finish()? {
yield event;
}
};
Ok(Box::pin(events))
}
pub(crate) fn with_metrics(mut self, metrics: Metrics) -> Self {
self.metrics = metrics;
self
}
}
/// Trailing headers of a response: known up front, or published later.
#[derive(Debug)]
pub(crate) enum TrailerState {
    /// Trailers (or their absence) were known when the response was built.
    Ready(Option<HeaderMap>),
    /// Trailers are read from a shared cell, which stays empty until someone
    /// (presumably the body reader) sets it.
    Deferred(Arc<OnceLock<HeaderMap>>),
}

impl TrailerState {
    /// Consumes the state and returns the trailers available right now.
    ///
    /// For `Deferred`, this is a clone of the cell's value, or `None` if the
    /// cell has not been set yet.
    pub(crate) fn take(self) -> Option<HeaderMap> {
        match self {
            Self::Deferred(cell) => cell.get().cloned(),
            Self::Ready(ready) => ready,
        }
    }
}
impl Response {
    /// Consumes the response and returns whatever trailers are available now.
    fn take_trailers(self) -> Option<HeaderMap> {
        let Self { trailers, .. } = self;
        trailers.take()
    }
}
/// Incremental decoder for the `text/event-stream` format.
///
/// Usage:
/// 1. Append raw body bytes with [`SseParser::push`].
/// 2. Pull complete events with [`SseParser::next_event`] until it returns `Ok(None)`.
/// 3. Call [`SseParser::finish`] once the body has ended.
///
/// Lines may end in LF, CRLF or a bare CR, as the event-stream grammar allows.
/// Lines are split on raw bytes. This is safe for UTF-8 input because CR and LF
/// never occur inside a multi-byte sequence.
#[derive(Default)]
struct SseParser {
    /// Bytes received so far. Everything before `consumed` has already been
    /// parsed and is discarded in bulk by `compact`.
    buffer: Vec<u8>,
    /// Start of the line currently being assembled.
    consumed: usize,
    /// Where the next search for a line terminator begins. This ensures a long
    /// line that arrives over many chunks is scanned only once.
    /// Invariant: `scanned >= consumed`.
    scanned: usize,
    /// The previous line ended with a CR that was the last buffered byte. If
    /// the next byte is an LF, it completes that CRLF and is skipped.
    pending_lf: bool,
    /// `data` field values of the event being assembled.
    data_lines: Vec<String>,
    /// `event` field of the event being assembled.
    event_type: Option<String>,
    /// Last `id` field received; persists across events.
    last_event_id: Option<String>,
    /// Pending `retry` value, attached to the next dispatched event.
    retry: Option<std::time::Duration>,
    /// Whether the first line has been checked for a leading BOM.
    bom_checked: bool,
}

impl SseParser {
    /// Appends raw body bytes to the internal buffer.
    fn push(&mut self, chunk: &[u8]) {
        self.buffer.extend_from_slice(chunk);
    }

    /// Parses buffered lines until an event completes or the input runs out.
    ///
    /// Returns `Ok(None)` when more input is needed. An unterminated trailing
    /// line stays buffered for the next call.
    ///
    /// # Errors
    /// Fails with `ErrorKind::Decode` if a line is not valid UTF-8.
    fn next_event(&mut self) -> Result<Option<SseEvent>> {
        loop {
            if self.pending_lf && self.consumed < self.buffer.len() {
                self.pending_lf = false;
                if self.buffer[self.consumed] == b'\n' {
                    self.consumed += 1;
                    self.scanned = self.scanned.max(self.consumed);
                }
            }
            let found = self.buffer[self.scanned..]
                .iter()
                .position(|&byte| byte == b'\n' || byte == b'\r');
            let Some(offset) = found else {
                self.scanned = self.buffer.len();
                self.compact();
                return Ok(None);
            };
            let line_start = self.consumed;
            let line_end = self.scanned + offset;
            let next_byte = self.buffer.get(line_end + 1).copied();
            let terminator_len = match (self.buffer[line_end], next_byte) {
                (b'\r', Some(b'\n')) => 2,
                (b'\r', None) => {
                    // The LF of a CRLF may still be in the next chunk.
                    self.pending_lf = true;
                    1
                }
                _ => 1,
            };
            self.consumed = line_end + terminator_len;
            self.scanned = self.consumed;
            // Move the buffer out so `process_line` can borrow `self` mutably
            // while reading the line. Taking a Vec does not allocate.
            let buffer = std::mem::take(&mut self.buffer);
            let outcome = self.process_line(&buffer[line_start..line_end]);
            self.buffer = buffer;
            if let Some(event) = outcome? {
                return Ok(Some(event));
            }
        }
    }

    /// Drops already-parsed bytes from the front of the buffer in one move.
    fn compact(&mut self) {
        if self.consumed > 0 {
            self.buffer.drain(..self.consumed);
            self.scanned -= self.consumed;
            self.consumed = 0;
        }
    }

    /// Flushes the parser at end of input.
    ///
    /// A trailing line with no terminator is processed as if it had one. Any
    /// event still being assembled is then dispatched.
    ///
    /// NOTE(review): the WHATWG algorithm discards an unterminated event at end
    /// of stream instead. Flushing it is this parser's existing behaviour.
    ///
    /// # Errors
    /// Fails with `ErrorKind::Decode` if the trailing line is not valid UTF-8.
    fn finish(&mut self) -> Result<Option<SseEvent>> {
        let buffer = std::mem::take(&mut self.buffer);
        let mut start = self.consumed;
        // The previous chunk ended in CR and only its LF arrived afterwards.
        if std::mem::take(&mut self.pending_lf) && buffer.get(start) == Some(&b'\n') {
            start += 1;
        }
        self.consumed = 0;
        self.scanned = 0;
        let tail = &buffer[start..];
        if !tail.is_empty() {
            if let Some(event) = self.process_line(tail)? {
                return Ok(Some(event));
            }
        }
        Ok(self.dispatch_event())
    }

    /// Applies one line (terminator already removed) to the event being built.
    ///
    /// Returns the completed event when `line` is blank and there is data to
    /// dispatch. Comment lines (leading `:`) and unknown fields are ignored.
    ///
    /// # Errors
    /// Fails with `ErrorKind::Decode` if `line` is not valid UTF-8.
    fn process_line(&mut self, line: &[u8]) -> Result<Option<SseEvent>> {
        let mut line = std::str::from_utf8(line).map_err(|err| {
            Error::with_source(ErrorKind::Decode, "response body is not valid utf-8", err)
        })?;
        if !self.bom_checked {
            // Only a BOM at the very start of the stream is stripped.
            self.bom_checked = true;
            line = line.strip_prefix('\u{feff}').unwrap_or(line);
        }
        if line.is_empty() {
            return Ok(self.dispatch_event());
        }
        if line.starts_with(':') {
            return Ok(None);
        }
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "data" => self.data_lines.push(value.to_owned()),
            "event" => self.event_type = Some(value.to_owned()),
            "id" => {
                if !value.contains('\0') {
                    self.last_event_id = Some(value.to_owned());
                }
            }
            "retry" => {
                // Only values made entirely of ASCII digits count. `u64::from_str`
                // alone would also accept a leading `+`. Values that overflow
                // u64 are ignored.
                let digits_only = value.bytes().all(|byte| byte.is_ascii_digit());
                if let Some(millis) = value.parse::<u64>().ok().filter(|_| digits_only) {
                    self.retry = Some(std::time::Duration::from_millis(millis));
                }
            }
            _ => {}
        }
        Ok(None)
    }

    /// Emits the assembled event, if it has any data, and resets per-event state.
    ///
    /// The event type is reset even when nothing is dispatched. The last event
    /// id persists. A pending retry value is attached to the event and cleared.
    fn dispatch_event(&mut self) -> Option<SseEvent> {
        let data_lines = std::mem::take(&mut self.data_lines);
        let event = self.event_type.take();
        if data_lines.is_empty() {
            return None;
        }
        Some(SseEvent {
            event: event
                .filter(|v| !v.is_empty())
                .unwrap_or_else(|| "message".into()),
            data: data_lines.join("\n"),
            id: self.last_event_id.clone().filter(|value| !value.is_empty()),
            retry: self.retry.take(),
        })
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use futures_lite::{StreamExt, future::block_on, stream};
    use super::{Response, SseEvent, StatusCode, Version};
    use crate::{Body, HeaderMap, Url};

    const API: &str = "https://api.example.com";

    /// Builds an HTTP/1.1 response without trailers from the given parts.
    fn build_response(
        status: StatusCode,
        url: &'static str,
        headers: HeaderMap,
        body: Body,
    ) -> Response {
        Response::new(
            status,
            Version::Http11,
            Url::parse(url).unwrap(),
            headers,
            None,
            body,
        )
    }

    /// Headers holding a single `content-type` value.
    fn with_content_type(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", value).unwrap();
        headers
    }

    /// Drives `Response::sse` to completion, panicking on any stream error.
    fn collect_sse(response: Response) -> Vec<SseEvent> {
        let events = response.sse().unwrap();
        block_on(events.map(|event| event.unwrap()).collect::<Vec<_>>())
    }

    #[test]
    fn response_text_consumes_body() {
        let response = build_response(StatusCode::OK, API, HeaderMap::new(), Body::from("hello"));
        assert_eq!(block_on(response.text()).unwrap(), "hello");
    }

    #[test]
    fn response_json_decodes_payload() {
        let payload = Body::from(Bytes::from_static(br#"{"ok":true}"#));
        let response = build_response(StatusCode::OK, API, HeaderMap::new(), payload);
        let json: serde_json::Value = block_on(response.json()).unwrap();
        assert_eq!(json["ok"], true);
    }

    #[test]
    fn response_error_for_status_returns_error_on_4xx() {
        let not_found =
            build_response(StatusCode::new(404), API, HeaderMap::new(), Body::default());
        assert!(not_found.error_for_status().is_err());
    }

    #[test]
    fn response_cookies_reads_set_cookie_headers() {
        let mut headers = HeaderMap::new();
        for cookie in ["a=1; Path=/", "b=2; Path=/"] {
            headers.append("set-cookie", cookie).unwrap();
        }
        let response = build_response(StatusCode::OK, API, headers, Body::default());
        assert_eq!(response.cookies(), vec!["a=1; Path=/", "b=2; Path=/"]);
    }

    #[test]
    fn response_bytes_stream_yields_body_once() {
        let response = build_response(StatusCode::OK, API, HeaderMap::new(), Body::from("hello"));
        let mut chunks = response.bytes_stream();
        assert_eq!(
            block_on(chunks.next()).unwrap().unwrap(),
            Bytes::from_static(b"hello")
        );
        assert!(block_on(chunks.next()).is_none());
    }

    #[test]
    fn response_bytes_stream_uses_stream_body() {
        let source = stream::iter(vec![
            Ok(Bytes::from_static(b"hello")),
            Ok(Bytes::from_static(b"world")),
        ]);
        let response = build_response(
            StatusCode::OK,
            "https://example.com",
            HeaderMap::new(),
            Body::from_stream(Box::pin(source)),
        );
        let chunks: Vec<Bytes> = block_on(response.bytes_stream().collect::<Vec<_>>())
            .into_iter()
            .map(|chunk| chunk.unwrap())
            .collect();
        assert_eq!(
            chunks,
            [Bytes::from_static(b"hello"), Bytes::from_static(b"world")]
        );
    }

    #[test]
    fn response_sse_parses_events() {
        let response = build_response(
            StatusCode::OK,
            API,
            with_content_type("text/event-stream; charset=utf-8"),
            Body::from("id: 1\nevent: message\nretry: 1500\ndata: hello\ndata: world\n\n"),
        );
        assert_eq!(
            collect_sse(response),
            vec![SseEvent {
                id: Some("1".into()),
                event: "message".into(),
                data: "hello\nworld".into(),
                retry: Some(std::time::Duration::from_millis(1500)),
            }]
        );
    }

    #[test]
    fn response_sse_streams_events_across_chunk_boundaries() {
        let source = stream::iter(vec![
            Ok(Bytes::from_static(b"id: 1\r\n")),
            Ok(Bytes::from_static(b"event: update\r\n")),
            Ok(Bytes::from_static(b"data: hello\r\n")),
            Ok(Bytes::from_static(b"data:\r\n")),
            Ok(Bytes::from_static(b"data: world\r")),
            Ok(Bytes::from_static(b"\n\r\n: keepalive\r\nid: 2\r\n\r\n")),
            Ok(Bytes::from_static(b"data: next\r\n\r\n")),
        ]);
        let response = build_response(
            StatusCode::OK,
            "https://example.com/events",
            with_content_type("text/event-stream"),
            Body::from_stream(Box::pin(source)),
        );
        assert_eq!(
            collect_sse(response),
            vec![
                SseEvent {
                    id: Some("1".into()),
                    event: "update".into(),
                    data: "hello\n\nworld".into(),
                    retry: None,
                },
                SseEvent {
                    id: Some("2".into()),
                    event: "message".into(),
                    data: "next".into(),
                    retry: None,
                },
            ]
        );
    }

    #[test]
    fn response_sse_rejects_non_sse_content_type() {
        let response = build_response(
            StatusCode::OK,
            API,
            with_content_type("application/json"),
            Body::from("{}"),
        );
        let err = response.sse().err().unwrap();
        assert_eq!(err.kind(), &crate::ErrorKind::Decode);
        assert!(err.to_string().contains("response is not an SSE stream"));
    }
}