use std::future::Future;
use std::io;
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, ReadBuf};
use crate::http::CookieJar;
use crate::{Request, Response};
/// A response produced by `LocalResponse::new`, kept alive together with the
/// boxed `Request` it was produced from.
///
/// This type is self-referential: `response` and `cookies` may borrow from
/// `*_request` (see the `unsafe` lifetime extension in `new`).
pub struct LocalResponse<'c> {
    // SAFETY: `response` and `cookies` may borrow from `*_request`. Struct
    // fields are dropped in declaration order, so these dependent fields must
    // stay declared BEFORE `_request`. Do not reorder.
    response: Response<'c>,
    cookies: CookieJar<'c>,
    // Owns the request. It is boxed so that its address stays fixed when the
    // struct itself is moved.
    _request: Box<Request<'c>>,
}
// NOTE(review): this intentionally empty `Drop` impl appears to exist for its
// type-system effects rather than for any runtime work. A type that implements
// `Drop` cannot have its fields moved out of it (E0509), so `_request` can
// never be separated from the `response`/`cookies` that may borrow from it.
// Field declaration order then ensures those borrowers are dropped before the
// boxed request. Do not remove this impl without re-checking the `unsafe`
// block in `LocalResponse::new`.
impl Drop for LocalResponse<'_> {
    fn drop(&mut self) {}
}
impl<'c> LocalResponse<'c> {
    /// Returns a future that runs `f` on `req` and bundles the resulting
    /// `Response` with the boxed request it may borrow from.
    ///
    /// The cookies set on the response are copied as owned values into a new
    /// `CookieJar` (via `add_original`), which is later exposed through
    /// `_cookies`.
    pub(crate) fn new<F, O>(req: Request<'c>, f: F) -> impl Future<Output = LocalResponse<'c>>
    where
        F: FnOnce(&'c Request<'c>) -> O + Send,
        O: Future<Output = Response<'c>> + Send,
    {
        // The request is boxed so that it lives at a heap address that does
        // not change when the `Box` is moved into the async block and then
        // into the returned struct.
        let boxed_req = Box::new(req);
        // SAFETY: this fabricates a `&'c` reference into the boxed request.
        // Soundness relies on two things:
        //   1. The box is stored in the returned `LocalResponse`, whose field
        //      order drops `response` and `cookies` before `_request`.
        //   2. No `'c`-lived reference into the request escapes the struct.
        //      The `impl LocalResponse<'_>` methods visible here only hand out
        //      borrows tied to `&self`. The methods generated by
        //      `pub_response_impl!` must do the same (TODO confirm).
        // NOTE(review): moving `boxed_req` while `request` is still live may
        // be reported by Miri under Stacked Borrows. This is worth checking
        // with `cargo miri test`.
        let request: &'c Request<'c> = unsafe { &*(&*boxed_req as *const _) };
        async move {
            let response: Response<'c> = f(request).await;
            // Seed a fresh jar with owned copies of every cookie on the response.
            let mut cookies = CookieJar::new(None, request.rocket());
            for cookie in response.cookies() {
                cookies.add_original(cookie.into_owned());
            }
            LocalResponse {
                _request: boxed_req,
                cookies,
                response,
            }
        }
    }
}
impl LocalResponse<'_> {
    /// Returns the dispatched `Response`, borrowed for no longer than `self`.
    pub(crate) fn _response(&self) -> &Response<'_> {
        &self.response
    }

    /// Returns the jar holding owned copies of the response's cookies,
    /// borrowed for no longer than `self`.
    pub(crate) fn _cookies(&self) -> &CookieJar<'_> {
        &self.cookies
    }

    /// Consumes `self` and reads the response body into a `String`.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by the body's `to_string()`.
    pub(crate) async fn _into_string(mut self) -> io::Result<String> {
        self.response.body_mut().to_string().await
    }

    /// Consumes `self` and reads the response body into a byte vector.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by the body's `to_bytes()`.
    pub(crate) async fn _into_bytes(mut self) -> io::Result<Vec<u8>> {
        self.response.body_mut().to_bytes().await
    }

    /// Consumes `self` and deserializes the body as JSON with
    /// `serde_json::from_reader`, run on a blocking task.
    ///
    /// Returns `None` if the blocking task fails or deserialization errors.
    #[cfg(feature = "json")]
    async fn _into_json<T>(self) -> Option<T>
    where
        T: Send + serde::de::DeserializeOwned + 'static,
    {
        self.blocking_read(|r| serde_json::from_reader(r))
            .await?
            .ok()
    }

    /// Consumes `self` and deserializes the body as MessagePack with
    /// `rmp_serde::from_read`, run on a blocking task.
    ///
    /// Returns `None` if the blocking task fails or deserialization errors.
    #[cfg(feature = "msgpack")]
    async fn _into_msgpack<T>(self) -> Option<T>
    where
        T: Send + serde::de::DeserializeOwned + 'static,
    {
        self.blocking_read(|r| rmp_serde::from_read(r)).await?.ok()
    }

    /// Runs the synchronous reader function `f` on a `spawn_blocking` task,
    /// feeds it this response's body, and returns whatever `f` returns.
    ///
    /// The body is read asynchronously here in chunks. Each chunk is forwarded
    /// over a bounded channel to a blocking `io::Read` adapter that is handed
    /// to `f`. If `f` returns before consuming the whole body, pumping stops at
    /// the next failed send and `f`'s result is still returned.
    ///
    /// Returns `None` only if the blocking task cannot be joined (for example,
    /// because `f` panicked).
    #[cfg(any(feature = "json", feature = "msgpack"))]
    async fn blocking_read<T, F>(mut self, f: F) -> Option<T>
    where
        T: Send + 'static,
        F: FnOnce(&mut dyn io::Read) -> T + Send + 'static,
    {
        use tokio::io::AsyncReadExt;
        use tokio::sync::mpsc;

        /// Synchronous `io::Read` over body chunks received from the async side.
        struct ChanReader {
            /// The chunk currently being served, if any.
            last: Option<io::Cursor<Vec<u8>>>,
            rx: mpsc::Receiver<io::Result<Vec<u8>>>,
        }

        impl std::io::Read for ChanReader {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                loop {
                    // Serve any unread bytes from the current chunk first.
                    if let Some(ref mut cursor) = self.last {
                        if cursor.position() < cursor.get_ref().len() as u64 {
                            return std::io::Read::read(cursor, buf);
                        }
                    }

                    // The current chunk is exhausted, so block for the next one.
                    // This only runs inside the `spawn_blocking` task below,
                    // where blocking is permitted. `None` means every sender
                    // has been dropped, i.e. the end of the body.
                    match self.rx.blocking_recv() {
                        Some(chunk) => self.last = Some(io::Cursor::new(chunk?)),
                        None => return Ok(0),
                    }
                }
            }
        }

        let (tx, rx) = mpsc::channel(2);
        let reader = tokio::task::spawn_blocking(move || {
            let mut reader = ChanReader { last: None, rx };
            f(&mut reader)
        });

        loop {
            let mut buf = Vec::with_capacity(1024);
            match self.read_buf(&mut buf).await {
                Ok(0) => break,
                Ok(_) => {
                    // A send error means the reader hung up: `f` has returned
                    // and dropped the receiver. Stop pumping and collect its
                    // result below instead of discarding it.
                    if tx.send(Ok(buf)).await.is_err() {
                        break;
                    }
                }
                Err(e) => {
                    // Forward the error so that `f` observes it. If the reader
                    // has already hung up, there is nobody left to notify.
                    let _ = tx.send(Err(e)).await;
                    break;
                }
            }
        }

        // Close the channel so the reader's `blocking_recv` yields `None`
        // (EOF). Awaiting `reader` while still holding `tx` would deadlock if
        // `f` reads to the end of its input.
        drop(tx);
        reader.await.ok()
    }

    pub_response_impl!("# use rkt::local::asynchronous::Client;\n\
        use rkt::local::asynchronous::LocalResponse;" async await);
}
impl AsyncRead for LocalResponse<'_> {
    // Reading a `LocalResponse` is reading its body: this delegates straight
    // to the body's own `AsyncRead` implementation.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let body = self.get_mut().response.body_mut();
        Pin::new(body).poll_read(cx, buf)
    }
}
impl std::fmt::Debug for LocalResponse<'_> {
    // Formats exactly as the wrapped `Response` does. The boxed request and
    // the cookie jar are not part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let response = self._response();
        std::fmt::Debug::fmt(response, f)
    }
}