use std::future::{poll_fn, Future};
use std::pin::{pin, Pin};
use std::sync::OnceLock;
use futures::future::{select, Either};
use nyquest_interface::client::ClientOptions;
use nyquest_interface::r#async::{AsyncClient, AsyncResponse, Request};
use nyquest_interface::Result as NyquestResult;
use tokio::runtime::{Handle, Runtime};
use crate::client::ReqwestClient;
use crate::error::{ReqwestBackendError, Result};
use crate::response::ReqwestResponse;
#[cfg(feature = "async-stream")]
mod stream;
/// Async nyquest client backed by `reqwest`.
///
/// Cloning clones the wrapped [`ReqwestClient`] (via `#[derive(Clone)]`).
#[derive(Clone)]
pub struct ReqwestAsyncClient {
    // Shared backend state. `request` reads its `managed_runtime` (the lazily
    // initialised fallback runtime) and `max_response_buffer_size` from here.
    inner: ReqwestClient,
}
impl ReqwestAsyncClient {
    /// Creates an async client configured from nyquest `options`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while constructing the underlying
    /// [`ReqwestClient`].
    pub fn new(options: ClientOptions) -> NyquestResult<Self> {
        let client = Self {
            inner: ReqwestClient::new(options)?,
        };
        Ok(client)
    }
}
impl AsyncClient for ReqwestAsyncClient {
    type Response = ReqwestAsyncResponse;

    fn describe(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ReqwestAsyncClient")
    }

    /// Sends `req` and wraps the resulting `reqwest::Response`.
    ///
    /// The send future runs on the ambient tokio runtime when one is present,
    /// otherwise on the client's managed runtime (see
    /// `execute_with_runtime_async`). With the `async-stream` feature, body
    /// streams registered while building the request are driven concurrently
    /// with the send future.
    ///
    /// # Errors
    ///
    /// Returns an error if building the request or sending it fails.
    ///
    /// # Panics
    ///
    /// The body-stream callback panics if it is invoked while the
    /// `async-stream` feature is disabled.
    async fn request(&self, req: Request) -> NyquestResult<Self::Response> {
        #[cfg(feature = "async-stream")]
        let mut stream_task_collection = stream::StreamTaskCollection::default();
        let request_builder = self.inner.request(req, |stream| {
            #[cfg(feature = "async-stream")]
            {
                use nyquest_interface::r#async::BoxedStream;
                let size = match &stream {
                    BoxedStream::Sized { content_length, .. } => Some(*content_length),
                    BoxedStream::Unsized { .. } => None,
                };
                let stream = stream_task_collection.add_stream(stream);
                (reqwest::Body::wrap(stream), size)
            }
            #[cfg(not(feature = "async-stream"))]
            {
                let _ = stream;
                unreachable!("async-stream feature is disabled")
            }
        })?;
        let req_task = pin!(execute_with_runtime_async(
            &self.inner.managed_runtime,
            || async {
                request_builder
                    .send()
                    .await
                    .map_err(ReqwestBackendError::Reqwest)
            }
        ));
        #[cfg(not(feature = "async-stream"))]
        let stream_task = std::future::pending::<()>();
        #[cfg(feature = "async-stream")]
        let stream_task = pin!(stream_task_collection.execute());
        // Race the send future against the body-stream driver so streaming
        // bodies make progress while the request is in flight. If the driver
        // side ever completes first (e.g. every body stream has been fully
        // consumed), there is nothing left to drive: keep awaiting the request
        // instead of panicking.
        let (response, handle) = match select(req_task, stream_task).await {
            Either::Left((res, _)) => res,
            Either::Right((_, req_task)) => req_task.await,
        };
        ReqwestAsyncResponse::new(response?, self.inner.max_response_buffer_size, handle)
            .await
            .map_err(Into::into)
    }
}
/// Builds the fallback multi-threaded tokio runtime: one worker thread named
/// `nyquest-reqwest-async`, with all drivers enabled.
///
/// It is only constructed (lazily) when a request is issued outside of any
/// tokio runtime context.
///
/// # Panics
///
/// Panics if tokio fails to build the runtime.
fn create_managed_runtime() -> Runtime {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder
        .enable_all()
        .worker_threads(1)
        .thread_name("nyquest-reqwest-async");
    builder
        .build()
        .expect("Failed to create managed tokio runtime")
}
/// Runs the future produced by `task` inside a tokio runtime context and
/// returns its output together with the handle of the runtime it ran on.
///
/// If the caller is already inside a tokio runtime, the future is awaited in
/// place and that ambient runtime's handle is returned. Otherwise the future
/// is spawned onto the lazily-initialised `managed_runtime` and awaited
/// through its join handle.
///
/// # Panics
///
/// Panics if the task spawned on the managed runtime panics (or is
/// otherwise reported as failed by its join handle).
async fn execute_with_runtime_async<F, Fut, T: Send + 'static>(
    managed_runtime: &OnceLock<Runtime>,
    task: F,
) -> (T, Handle)
where
    F: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = T> + Send,
{
    match Handle::try_current() {
        // Already inside a runtime: just run the task here.
        Ok(ambient) => {
            let output = task().await;
            (output, ambient)
        }
        // No runtime context: hop onto the managed runtime and capture its
        // handle from within the spawned task.
        Err(_) => {
            let join = managed_runtime
                .get_or_init(create_managed_runtime)
                .spawn(async {
                    let runtime_handle = Handle::current();
                    let output = task().await;
                    (output, runtime_handle)
                });
            join.await
                .expect("spawned task panicked from managed runtime")
        }
    }
}
/// Async response produced by [`ReqwestAsyncClient`].
pub struct ReqwestAsyncResponse {
    // Wrapped response; status, headers and body reads are delegated to it.
    response: ReqwestResponse,
    // Handle of the tokio runtime the request was executed on. It is entered
    // while polling body reads (`bytes`, and `poll_read` with `async-stream`).
    current_handle: Handle,
}
impl ReqwestAsyncResponse {
    /// Wraps a `reqwest` response, applying the optional buffer limit and
    /// remembering the runtime handle used for subsequent body reads.
    ///
    /// Currently infallible; the `Result` return leaves room for validation.
    async fn new(
        response: reqwest::Response,
        max_response_buffer_size: Option<u64>,
        current_handle: Handle,
    ) -> Result<Self> {
        let response = ReqwestResponse::new(response, max_response_buffer_size);
        Ok(Self {
            current_handle,
            response,
        })
    }
}
impl AsyncResponse for ReqwestAsyncResponse {
    /// Writes `ReqwestAsyncResponse(status: <code>)` for diagnostics.
    fn describe(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let code = self.status();
        write!(f, "ReqwestAsyncResponse(status: {})", code)
    }

    /// HTTP status code of the wrapped response.
    fn status(&self) -> u16 {
        let Self { response, .. } = self;
        response.status()
    }

    /// `Content-Length` as reported by the wrapped response, if any.
    fn content_length(&self) -> Option<u64> {
        let Self { response, .. } = self;
        response.content_length()
    }

    /// All values of the header named `header`.
    fn get_header(&self, header: &str) -> NyquestResult<Vec<String>> {
        let Self { response, .. } = self;
        response.get_header(header)
    }

    /// Reads the full body and decodes it as text.
    ///
    /// With the `charset` feature the encoding is chosen by the wrapped
    /// response before the body is consumed; otherwise the bytes are decoded
    /// as UTF-8, replacing invalid sequences.
    async fn text(self: Pin<&mut Self>) -> NyquestResult<String> {
        #[cfg(feature = "charset")]
        let encoding = self.response.get_best_encoding();
        let bytes = AsyncResponse::bytes(self).await?;
        #[cfg(feature = "charset")]
        let text = encoding.decode(&bytes).0.into_owned();
        #[cfg(not(feature = "charset"))]
        let text = String::from_utf8_lossy(&bytes).into_owned();
        Ok(text)
    }

    /// Collects the whole body, entering the captured runtime handle on every
    /// poll so the underlying reads run in that runtime's context.
    async fn bytes(mut self: Pin<&mut Self>) -> NyquestResult<Vec<u8>> {
        let this = &mut *self;
        let handle = &this.current_handle;
        let mut collect = pin!(this.response.collect_all_bytes());
        poll_fn(|cx| {
            let _guard = handle.enter();
            collect.as_mut().poll(cx)
        })
        .await
    }
}
#[cfg(feature = "async-stream")]
impl futures::AsyncRead for ReqwestAsyncResponse {
    /// Copies buffered body bytes into `buf`, polling for more data frames
    /// (inside the captured runtime context) whenever the buffer is drained.
    ///
    /// Returns `Ok(0)` at end of body, and also immediately when `buf` is
    /// empty, as permitted by the `Read`/`AsyncRead` contract.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::task::{ready, Poll};
        // A zero-length destination can never receive bytes, so `write_to`
        // would keep reporting 0 and the loop below would keep polling for
        // (and buffering) further frames instead of returning. Short-circuit.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        loop {
            // Serve whatever is already buffered first.
            let written = self.response.write_to(buf)?;
            if written > 0 {
                return Poll::Ready(Ok(written));
            }
            // Buffer drained: poll for the next frame within the runtime
            // context the request was executed on.
            let _enter = self.current_handle.enter();
            let received = ready!(self.response.poll_receive_data_frame_buffered(cx))?;
            if received == 0 {
                // No further data: end of body.
                break Poll::Ready(Ok(0));
            }
        }
    }
}
impl nyquest_interface::r#async::AsyncBackend for crate::ReqwestBackend {
type AsyncClient = ReqwestAsyncClient;
async fn create_async_client(
&self,
options: nyquest_interface::client::ClientOptions,
) -> NyquestResult<Self::AsyncClient> {
ReqwestAsyncClient::new(options)
}
}