nyquest-backend-reqwest 0.2.1

reqwest backend for nyquest HTTP client library
Documentation
use std::future::{poll_fn, Future};
use std::pin::{pin, Pin};
use std::sync::OnceLock;

use futures::future::{select, Either};
use nyquest_interface::client::ClientOptions;
use nyquest_interface::r#async::{AsyncClient, AsyncResponse, Request};
use nyquest_interface::Result as NyquestResult;
use tokio::runtime::{Handle, Runtime};

use crate::client::ReqwestClient;
use crate::error::{ReqwestBackendError, Result};
use crate::response::ReqwestResponse;

#[cfg(feature = "async-stream")]
mod stream;

/// Asynchronous nyquest client implemented on top of reqwest.
///
/// All request building and configuration lives in the wrapped
/// `ReqwestClient`; this type only adds the async execution layer.
#[derive(Clone)]
pub struct ReqwestAsyncClient {
    inner: ReqwestClient,
}

impl ReqwestAsyncClient {
    /// Build an async client configured from `options`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `ReqwestClient::new` reports, converted into the
    /// nyquest error type.
    pub fn new(options: ClientOptions) -> NyquestResult<Self> {
        Ok(Self {
            inner: ReqwestClient::new(options)?,
        })
    }
}

impl AsyncClient for ReqwestAsyncClient {
    type Response = ReqwestAsyncResponse;

    fn describe(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ReqwestAsyncClient")
    }

    /// Send `req` and resolve once the response head has arrived.
    ///
    /// With the `async-stream` feature, streaming request bodies are handed to a
    /// `StreamTaskCollection` that is polled concurrently with the request
    /// future. Without the feature, a streaming body is treated as unreachable.
    async fn request(&self, req: Request) -> NyquestResult<Self::Response> {
        #[cfg(feature = "async-stream")]
        let mut body_tasks = stream::StreamTaskCollection::default();

        let builder = self.inner.request(req, |body| {
            #[cfg(feature = "async-stream")]
            {
                use nyquest_interface::r#async::BoxedStream;
                // Capture the declared length before the stream is moved away.
                let known_len = match &body {
                    BoxedStream::Sized { content_length, .. } => Some(*content_length),
                    BoxedStream::Unsized { .. } => None,
                };
                (reqwest::Body::wrap(body_tasks.add_stream(body)), known_len)
            }
            #[cfg(not(feature = "async-stream"))]
            {
                let _ = body;
                unreachable!("async-stream feature is disabled")
            }
        })?;

        // Send on the caller's Tokio runtime if there is one, otherwise on the
        // managed runtime; the handle of whichever runtime was used comes back
        // alongside the result.
        let send_fut = pin!(execute_with_runtime_async(
            &self.inner.managed_runtime,
            move || async move {
                builder
                    .send()
                    .await
                    .map_err(ReqwestBackendError::Reqwest)
            }
        ));

        // Future driving the streaming request bodies; without the feature there
        // is nothing to drive, so a never-completing future stands in.
        #[cfg(feature = "async-stream")]
        let body_fut = pin!(body_tasks.execute());
        #[cfg(not(feature = "async-stream"))]
        let body_fut = std::future::pending::<()>();

        let (sent, runtime_handle) = match select(send_fut, body_fut).await {
            Either::Left((outcome, _)) => outcome,
            Either::Right(_) => unreachable!(),
        };
        ReqwestAsyncResponse::new(sent?, self.inner.max_response_buffer_size, runtime_handle)
            .await
            .map_err(Into::into)
    }
}

/// Build the fallback Tokio runtime used when a request is issued outside any
/// Tokio context.
///
/// The runtime is multi-threaded with a single worker thread named
/// `nyquest-reqwest-async`, and all Tokio drivers are enabled.
///
/// # Panics
///
/// Panics if Tokio fails to build the runtime.
fn create_managed_runtime() -> Runtime {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder
        .enable_all()
        .worker_threads(1)
        .thread_name("nyquest-reqwest-async");
    builder
        .build()
        .expect("Failed to create managed tokio runtime")
}

/// Run `task` on a Tokio runtime and return its output together with the
/// [`Handle`] of the runtime it ran on.
///
/// - If the caller is already inside a Tokio runtime, the future returned by
///   `task` is awaited in place and the current runtime's handle is returned.
/// - Otherwise, the lazily-initialized `managed_runtime` (built by
///   `create_managed_runtime`) is used. The work is *spawned* onto it and the
///   resulting `JoinHandle` is awaited from the caller's executor. No
///   `block_on` is involved, so the caller's thread is never blocked.
///
/// # Panics
///
/// If the spawned task panics, that panic is resumed on the caller with its
/// original payload. This function also panics if the spawned task is
/// cancelled, for example because the managed runtime shut down.
async fn execute_with_runtime_async<F, Fut, T: Send + 'static>(
    managed_runtime: &OnceLock<Runtime>,
    task: F,
) -> (T, Handle)
where
    F: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = T> + Send,
{
    if let Ok(handle) = Handle::try_current() {
        // Inside a Tokio runtime: run the task directly on it.
        return (task().await, handle);
    }

    // Outside any Tokio runtime: spawn onto the managed runtime. A JoinHandle
    // can be awaited from any executor.
    let runtime = managed_runtime.get_or_init(create_managed_runtime);
    let join = runtime.spawn(async {
        // Record the managed runtime's handle so callers can enter it later.
        let handle = Handle::current();
        let result = task().await;
        (result, handle)
    });
    match join.await {
        Ok(output) => output,
        // Re-raise the task's own panic instead of masking it with a new message.
        Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
        Err(err) => panic!("spawned task on managed runtime was cancelled: {err}"),
    }
}

/// Response produced by `ReqwestAsyncClient`.
///
/// Besides the wrapped `ReqwestResponse`, it keeps the [`Handle`] of the Tokio
/// runtime the request executed on. Body reads enter that runtime while they
/// poll.
pub struct ReqwestAsyncResponse {
    response: ReqwestResponse,
    current_handle: Handle,
}

impl ReqwestAsyncResponse {
    /// Wrap a raw `reqwest::Response`, applying `max_response_buffer_size` and
    /// remembering `current_handle` for later body reads.
    ///
    /// This never fails and awaits nothing at present; the `async`/`Result`
    /// shape is kept for the caller.
    async fn new(
        response: reqwest::Response,
        max_response_buffer_size: Option<u64>,
        current_handle: Handle,
    ) -> Result<Self> {
        let response = ReqwestResponse::new(response, max_response_buffer_size);
        Ok(Self {
            current_handle,
            response,
        })
    }
}

impl AsyncResponse for ReqwestAsyncResponse {
    fn describe(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let code = self.status();
        write!(f, "ReqwestAsyncResponse(status: {code})")
    }

    fn status(&self) -> u16 {
        self.response.status()
    }

    fn content_length(&self) -> Option<u64> {
        self.response.content_length()
    }

    fn get_header(&self, header: &str) -> NyquestResult<Vec<String>> {
        self.response.get_header(header)
    }

    /// Read the whole body and decode it as text.
    ///
    /// With the `charset` feature, decoding uses the response's best-guess
    /// encoding. Otherwise the body is decoded as UTF-8, replacing invalid
    /// sequences.
    async fn text(self: Pin<&mut Self>) -> NyquestResult<String> {
        // The encoding must be read before `self` is handed to `bytes`.
        #[cfg(feature = "charset")]
        let encoding = self.response.get_best_encoding();

        let body = AsyncResponse::bytes(self).await?;

        #[cfg(feature = "charset")]
        let decoded = encoding.decode(&body).0.into_owned();
        #[cfg(not(feature = "charset"))]
        let decoded = String::from_utf8_lossy(&body).into_owned();

        Ok(decoded)
    }

    /// Collect the entire body into memory.
    ///
    /// Each poll first enters the Tokio runtime recorded in `current_handle`.
    async fn bytes(self: Pin<&mut Self>) -> NyquestResult<Vec<u8>> {
        // Split into disjoint field borrows: the handle is only read, while the
        // response is borrowed mutably by the collecting future.
        let this = self.get_mut();
        let handle = &this.current_handle;
        let mut collect = pin!(this.response.collect_all_bytes());
        poll_fn(|cx| {
            let _guard = handle.enter();
            collect.as_mut().poll(cx)
        })
        .await
    }
}

#[cfg(feature = "async-stream")]
impl futures::AsyncRead for ReqwestAsyncResponse {
    /// Copy buffered body bytes into `buf`.
    ///
    /// When the internal buffer is drained, more data frames are pulled from
    /// the response, with the request's Tokio runtime entered while polling.
    ///
    /// Returns `Ok(0)` at the end of the body, and immediately when `buf` is
    /// empty.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::task::{ready, Poll};

        // Nothing fits into an empty buffer. Without this guard the loop below
        // could never return a non-zero count and would keep polling for body
        // frames instead of returning.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        // Take a plain `&mut Self` so the runtime guard borrows only
        // `current_handle` while `response` is polled mutably. Borrowing
        // through the `Pin` would tie up all of `self`.
        let this = self.get_mut();
        loop {
            let written = this.response.write_to(buf)?;
            if written > 0 {
                return Poll::Ready(Ok(written));
            }
            let _enter = this.current_handle.enter();
            let received = ready!(this.response.poll_receive_data_frame_buffered(cx))?;
            if received == 0 {
                return Poll::Ready(Ok(0));
            }
        }
    }
}

impl nyquest_interface::r#async::AsyncBackend for crate::ReqwestBackend {
    type AsyncClient = ReqwestAsyncClient;

    /// Create a `ReqwestAsyncClient` from `options`.
    ///
    /// Construction is synchronous; the `async` signature comes from the trait.
    async fn create_async_client(
        &self,
        options: nyquest_interface::client::ClientOptions,
    ) -> NyquestResult<Self::AsyncClient> {
        let client = ReqwestAsyncClient::new(options)?;
        Ok(client)
    }
}