rosu-render 0.5.1

An o!rdr wrapper to render osu! replays
Documentation
use std::{
    future::Future,
    marker::PhantomData,
    pin::{pin, Pin},
    task::{Context, Poll},
};

use http_body_util::{combinators::Collect, BodyExt};
use hyper::{body::Incoming, StatusCode};
use hyper_util::client::legacy::ResponseFuture;
use leaky_bucket::AcquireOwned;
use pin_project::pin_project;
use serde::de::DeserializeOwned;

use crate::ClientError;

use super::requestable::Requestable;

/// Future for a single o!rdr request, resolving to `Result<T, ClientError>`.
///
/// It first waits for the ratelimit permit (if any) and then advances through the
/// stages of `OrdrFutureState` until a result is produced.
#[pin_project(project = OrdrFutureProj)]
pub struct OrdrFuture<T> {
    /// Pending ratelimit acquisition. It is set to `None` once the permit has been
    /// acquired, and it is `None` from the start for futures built via `OrdrFuture::error`.
    #[pin]
    ratelimit: Option<AcquireOwned>,
    /// Current stage of the request.
    #[pin]
    state: OrdrFutureState<T>,
}

impl<T> OrdrFuture<T> {
    /// Creates a future that waits for `ratelimit` and then drives the HTTP request `fut`.
    pub(crate) const fn new(fut: Pin<Box<ResponseFuture>>, ratelimit: AcquireOwned) -> Self {
        Self {
            state: OrdrFutureState::InFlight(InFlight {
                phantom: PhantomData,
                fut,
            }),
            ratelimit: Some(ratelimit),
        }
    }

    /// Creates a future that yields `source` on its first poll, without sending any request.
    pub(crate) const fn error(source: ClientError) -> Self {
        Self {
            state: OrdrFutureState::Failed(Some(source)),
            ratelimit: None,
        }
    }

    /// Polls the ratelimit permit if one is still outstanding.
    ///
    /// Returns `Poll::Ready(())` when there is nothing left to wait for. That is either
    /// because no permit was pending, or because it was just acquired, in which case it
    /// is cleared so it is not polled again. Otherwise it returns `Poll::Pending`.
    fn await_ratelimit(
        mut ratelimit_opt: Pin<&mut Option<AcquireOwned>>,
        cx: &mut Context<'_>,
    ) -> Poll<()> {
        let Some(acquire) = ratelimit_opt.as_mut().as_pin_mut() else {
            return Poll::Ready(());
        };

        if acquire.poll(cx).is_pending() {
            return Poll::Pending;
        }

        ratelimit_opt.set(None);

        Poll::Ready(())
    }
}

impl<T: DeserializeOwned + Requestable> Future for OrdrFuture<T> {
    type Output = Result<T, ClientError>;

    /// Drives the request to completion.
    ///
    /// The stages run in this order:
    /// 1. Wait for the ratelimit permit.
    /// 2. Wait for the response head (`InFlight`).
    /// 3. Collect and deserialize the body (`Chunking`).
    ///
    /// A future created via `OrdrFuture::error` yields its stored error on the first poll.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        // Every arm either returns or moves `InFlight` -> `Chunking`, so the loop
        // body executes at most twice per call.
        loop {
            match this.state.as_mut().project() {
                OrdrFutureStateProj::InFlight(in_flight) => {
                    // The request is not polled until the ratelimit permit is acquired.
                    if Self::await_ratelimit(this.ratelimit.as_mut(), cx).is_pending() {
                        return Poll::Pending;
                    }

                    match in_flight.poll(cx) {
                        Poll::Ready(Ok(chunking)) => {
                            // Poll the body on the next iteration right away instead of
                            // self-waking and yielding back to the executor.
                            this.state.set(OrdrFutureState::Chunking(chunking));
                        }
                        Poll::Ready(Err(err)) => {
                            this.state.set(OrdrFutureState::Completed);

                            return Poll::Ready(Err(err));
                        }
                        Poll::Pending => return Poll::Pending,
                    }
                }
                OrdrFutureStateProj::Chunking(chunking) => {
                    return match chunking.poll(cx) {
                        Poll::Ready(res) => {
                            this.state.set(OrdrFutureState::Completed);

                            Poll::Ready(res)
                        }
                        Poll::Pending => Poll::Pending,
                    };
                }
                OrdrFutureStateProj::Failed(failed) => {
                    // The state moves to `Completed` right after, so the error is taken once.
                    let err = failed.take().expect("error already taken");
                    this.state.set(OrdrFutureState::Completed);

                    return Poll::Ready(Err(err));
                }
                OrdrFutureStateProj::Completed => panic!("future already completed"),
            }
        }
    }
}

/// Stages an `OrdrFuture` moves through while being polled.
#[pin_project(project = OrdrFutureStateProj)]
enum OrdrFutureState<T> {
    /// Response head received; the body is being collected and then converted into `T`.
    Chunking(#[pin] Chunking<T>),
    /// A result has already been returned; polling again panics.
    Completed,
    /// Created with an error that is returned on the first poll. It is wrapped in
    /// `Option` so it can be taken out by value through the `&mut` projection.
    Failed(Option<ClientError>),
    /// Waiting for the HTTP response. It is only polled once the ratelimit permit is acquired.
    InFlight(#[pin] InFlight<T>),
}

/// Stage that collects the full response body and turns it into `T` or an error.
#[pin_project]
struct Chunking<T> {
    /// Collects every frame of the response body into memory.
    #[pin]
    fut: Collect<Incoming>,
    /// Response status; a success status means the body is deserialized, otherwise
    /// `Requestable::response_error` builds the error.
    status: StatusCode,
    /// Marks the target type `T` without storing a value of it.
    phantom: PhantomData<T>,
}

impl<T: DeserializeOwned + Requestable> Future for Chunking<T> {
    type Output = Result<T, ClientError>;

    /// Collects the whole body, then produces a result from it:
    /// - On a success status, the body is deserialized as JSON into `T`.
    /// - Otherwise, the request-specific error is built from the status and raw body.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projection = self.project();

        let collected = match projection.fut.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(source)) => {
                return Poll::Ready(Err(ClientError::ChunkingResponse { source }));
            }
            Poll::Ready(Ok(collected)) => collected,
        };

        let bytes = collected.to_bytes();
        let status = *projection.status;

        if !status.is_success() {
            return Poll::Ready(Err(<T as Requestable>::response_error(status, bytes)));
        }

        // Keep the raw body in the error so failed payloads can be inspected.
        let parsed = serde_json::from_slice(&bytes).map_err(|source| ClientError::Parsing {
            body: bytes.into(),
            source,
        });

        Poll::Ready(parsed)
    }
}

/// Stage that waits for the HTTP response of the request.
#[pin_project]
struct InFlight<T> {
    /// The pending HTTP request.
    #[pin]
    fut: Pin<Box<ResponseFuture>>,
    /// Carries the response type `T` through to the `Chunking` stage created from this one.
    phantom: PhantomData<T>,
}

impl<T: Requestable> Future for InFlight<T> {
    type Output = Result<Chunking<T>, ClientError>;

    /// Waits for the response.
    ///
    /// A 503 response becomes `ClientError::ServiceUnavailable` immediately. A 429
    /// response is logged and then, like every other status, handed to a `Chunking`
    /// stage so its body is collected.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projection = self.project();

        let response = match projection.fut.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(source)) => {
                return Poll::Ready(Err(ClientError::RequestError { source }));
            }
            Poll::Ready(Ok(response)) => response,
        };

        let status = response.status();

        if status == StatusCode::SERVICE_UNAVAILABLE {
            let response = Box::new(response);

            return Poll::Ready(Err(ClientError::ServiceUnavailable { response }));
        }

        if status == StatusCode::TOO_MANY_REQUESTS {
            warn!("429 response: {response:?}");
        }

        let chunking = Chunking {
            fut: response.into_body().collect(),
            status,
            phantom: PhantomData,
        };

        Poll::Ready(Ok(chunking))
    }
}