use std::collections::HashSet;
use futures_util::{Stream, StreamExt};
use http::{HeaderValue, StatusCode, header, header::HeaderName};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use tokio_util::compat::TokioAsyncReadCompatExt;
use crate::{Response, test::json::TestJson, web::sse::Event};
/// A response returned by the test client, wrapped with assertion helpers
/// and body decoders.
///
/// Every `assert_*` method panics when its check fails, so these helpers are
/// meant to be called from tests.
pub struct TestResponse(pub Response);

impl TestResponse {
    pub(crate) fn new(resp: Response) -> Self {
        Self(resp)
    }

    /// Asserts that the response status code equals `status`.
    ///
    /// # Panics
    ///
    /// Panics if the status code differs.
    #[track_caller]
    pub fn assert_status(&self, status: StatusCode) {
        assert_eq!(self.0.status(), status, "unexpected response status");
    }

    /// Asserts that the response status code is `200 OK`.
    ///
    /// # Panics
    ///
    /// Panics if the status code is anything else.
    #[track_caller]
    pub fn assert_status_is_ok(&self) {
        self.assert_status(StatusCode::OK);
    }

    /// Asserts that the response has no header named `key`.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not a valid header name, or if the header is present
    /// (the message includes the first value found).
    #[track_caller]
    pub fn assert_header_is_not_exist<K>(&self, key: K)
    where
        K: TryInto<HeaderName>,
    {
        let key = key.try_into().map_err(|_| ()).expect("valid header name");
        if let Some(value) = self.0.headers().get(&key) {
            panic!("expect header `{key}` to be absent, but found {value:?}");
        }
    }

    /// Asserts that the response has at least one header named `key`.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not a valid header name or the header is missing.
    #[track_caller]
    pub fn assert_header_exist<K>(&self, key: K)
    where
        K: TryInto<HeaderName>,
    {
        let key = key.try_into().map_err(|_| ()).expect("valid header name");
        assert!(self.0.headers().contains_key(&key), "expect header `{key}`");
    }

    /// Asserts that the first value of header `key` equals `value`.
    ///
    /// Only the first field line is compared; use
    /// [`assert_header_all`](Self::assert_header_all) for repeated headers.
    ///
    /// # Panics
    ///
    /// Panics if `key`/`value` cannot be converted, if the header is missing,
    /// or if its first value differs from `value`.
    #[track_caller]
    pub fn assert_header<K, V>(&self, key: K, value: V)
    where
        K: TryInto<HeaderName>,
        V: TryInto<HeaderValue>,
    {
        let key = key.try_into().map_err(|_| ()).expect("valid header name");
        let value = value
            .try_into()
            .map_err(|_| ())
            .expect("valid header value");
        let actual = self
            .0
            .headers()
            .get(&key)
            .unwrap_or_else(|| panic!("expect header `{key}`"));
        assert_eq!(actual, value, "unexpected value for header `{key}`");
    }

    /// Asserts that header `key`, read as a comma-separated list, contains
    /// exactly the set of `values`.
    ///
    /// All field lines with that name are combined, as HTTP list headers
    /// may be split across lines. Each element is trimmed, and empty elements
    /// (e.g. from a trailing comma) are ignored. Order and duplicates do not
    /// matter because both sides are compared as sets.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not a valid header name, if the header is missing,
    /// if any of its values is not visible ASCII, or if the sets differ.
    #[track_caller]
    pub fn assert_header_csv<K, V, I>(&self, key: K, values: I)
    where
        K: TryInto<HeaderName>,
        V: AsRef<str>,
        I: IntoIterator<Item = V>,
    {
        // Keep the owned items alive so the set can borrow `&str` from them.
        let expected_owned = values.into_iter().collect::<Vec<_>>();
        let expected = expected_owned
            .iter()
            .map(|value| value.as_ref())
            .collect::<HashSet<_>>();
        let key = key.try_into().map_err(|_| ()).expect("valid header name");
        let headers = self.0.headers();
        if !headers.contains_key(&key) {
            panic!("expect header `{key}`");
        }
        // RFC 9110 §5.3: multiple field lines of a list header are equivalent
        // to one comma-joined line, so gather elements from every line.
        let actual = headers
            .get_all(&key)
            .iter()
            .flat_map(|value| value.to_str().expect("valid header value").split(','))
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .collect::<HashSet<_>>();
        assert_eq!(actual, expected, "unexpected values for header `{key}`");
    }

    /// Asserts that all values of header `key` match `values`, ignoring order.
    ///
    /// Unlike [`assert_header_csv`](Self::assert_header_csv), each field line
    /// is compared as a whole value and duplicates are significant.
    ///
    /// # Panics
    ///
    /// Panics if `key` or any item of `values` cannot be converted, or if the
    /// sorted lists of values differ.
    #[track_caller]
    pub fn assert_header_all<K, V, I>(&self, key: K, values: I)
    where
        K: TryInto<HeaderName>,
        V: TryInto<HeaderValue>,
        I: IntoIterator<Item = V>,
    {
        let key = key.try_into().map_err(|_| ()).expect("valid header name");
        let mut expected = values
            .into_iter()
            .map(|value| {
                value
                    .try_into()
                    .map_err(|_| ())
                    .expect("valid header value")
            })
            .collect::<Vec<HeaderValue>>();
        let mut actual = self
            .0
            .headers()
            .get_all(&key)
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        // Sort both sides so the comparison is order-insensitive.
        expected.sort();
        actual.sort();
        assert_eq!(actual, expected, "unexpected values for header `{key}`");
    }

    /// Asserts that the `Content-Type` header equals `content_type` exactly.
    ///
    /// # Panics
    ///
    /// Panics if the header is missing or differs (parameters such as
    /// `charset` are part of the comparison).
    #[track_caller]
    pub fn assert_content_type(&self, content_type: &str) {
        self.assert_header(header::CONTENT_TYPE, content_type);
    }

    /// Consumes the response and asserts that its body, read as a string,
    /// equals `text`.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be read as a string or differs from `text`.
    pub async fn assert_text(self, text: impl AsRef<str>) {
        assert_eq!(
            self.0.into_body().into_string().await.expect("expect body"),
            text.as_ref()
        );
    }

    /// Consumes the response and asserts that its raw body equals `bytes`.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be read or differs from `bytes`.
    pub async fn assert_bytes(self, bytes: impl AsRef<[u8]>) {
        assert_eq!(
            self.0.into_body().into_vec().await.expect("expect body"),
            bytes.as_ref()
        );
    }

    /// Consumes the response and asserts that its body, parsed as JSON,
    /// equals `json` serialized to a JSON value.
    ///
    /// The comparison is between parsed values, not raw text, so whitespace
    /// and key order in the body do not matter.
    ///
    /// # Panics
    ///
    /// Panics if the body is not valid JSON, if `json` cannot be serialized,
    /// or if the values differ.
    pub async fn assert_json(self, json: impl Serialize) {
        assert_eq!(
            self.0
                .into_body()
                .into_json::<Value>()
                .await
                .expect("expect body"),
            serde_json::to_value(json).expect("valid json")
        );
    }

    /// Consumes the response and asserts that its body text equals `xml`
    /// serialized with `quick_xml`.
    ///
    /// This is an exact string comparison.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be read, `xml` cannot be serialized, or the
    /// strings differ.
    #[cfg(feature = "xml")]
    pub async fn assert_xml(self, xml: impl Serialize) {
        assert_eq!(
            self.0.into_body().into_string().await.expect("expect body"),
            quick_xml::se::to_string(&xml).expect("valid xml")
        );
    }

    /// Consumes the response and asserts that its body text equals `yaml`
    /// serialized with `serde_yaml`.
    ///
    /// This is an exact string comparison.
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be read, `yaml` cannot be serialized, or the
    /// strings differ.
    #[cfg(feature = "yaml")]
    pub async fn assert_yaml(self, yaml: impl Serialize) {
        assert_eq!(
            self.0.into_body().into_string().await.expect("expect body"),
            serde_yaml::to_string(&yaml).expect("valid yaml")
        );
    }

    /// Consumes the response and decodes its body as a [`TestJson`].
    ///
    /// # Panics
    ///
    /// Panics if the body cannot be read or decoded.
    pub async fn json(self) -> TestJson {
        self.0
            .into_body()
            .into_json::<TestJson>()
            .await
            .expect("expect body")
    }

    /// Consumes the response and returns a stream of the server-sent events
    /// in its body.
    ///
    /// A message without an `id` is yielded with the default (empty) id.
    ///
    /// # Panics
    ///
    /// Panics immediately if `Content-Type` is not `text/event-stream`. The
    /// returned stream panics when it meets a frame that cannot be decoded.
    pub fn sse_stream(self) -> impl Stream<Item = Event> + Send + Unpin + 'static {
        self.assert_content_type("text/event-stream");
        sse_codec::decode_stream(self.0.into_body().into_async_read().compat())
            .map(|res| {
                let event = res.expect("valid sse frame");
                match event {
                    sse_codec::Event::Message { id, event, data } => Event::Message {
                        id: id.unwrap_or_default(),
                        event,
                        data,
                    },
                    sse_codec::Event::Retry { retry } => Event::Retry { retry },
                }
            })
            .boxed()
    }

    /// Like [`sse_stream`](Self::sse_stream), but deserializes the `data` of
    /// each message event from JSON into `T` and skips retry events.
    ///
    /// # Panics
    ///
    /// Same as [`sse_stream`](Self::sse_stream). The stream also panics when a
    /// message's data does not deserialize into `T`.
    pub fn typed_sse_stream<T: DeserializeOwned + 'static>(
        self,
    ) -> impl Stream<Item = T> + Send + Unpin + 'static {
        self.sse_stream()
            .filter_map(|event| async move {
                match event {
                    Event::Message { data, .. } => {
                        Some(serde_json::from_str::<T>(&data).expect("valid data"))
                    }
                    Event::Retry { .. } => None,
                }
            })
            .boxed()
    }

    /// Shorthand for [`typed_sse_stream::<TestJson>`](Self::typed_sse_stream).
    pub fn json_sse_stream(self) -> impl Stream<Item = TestJson> + Send + Unpin + 'static {
        self.typed_sse_stream::<TestJson>()
    }
}