use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use async_trait::async_trait;
use bytes::Bytes;
use futures::Stream;
use kithara_platform::{MaybeSend, MaybeSync};
use tokio_util::sync::CancellationToken;
use url::Url;
// Local alias module that re-exports the `mock` attribute macro from
// `kithara_test_macros`, so the `Net` trait attribute can be written as
// `kithara::mock(...)`.
// NOTE(review): the macro's generated code may also rely on this `kithara::` path —
// confirm before removing or renaming this module.
mod kithara {
    pub(crate) use kithara_test_macros::mock;
}
use crate::{
error::NetError,
retry::{DefaultRetryPolicy, RetryNet},
timeout::TimeoutNet,
types::{Headers, RangeSpec, RetryPolicy},
};
// Boxed, pinned, type-erased stream of byte chunks (or `NetError`s).
// Off wasm the stream must also be `Send`; on wasm32 that bound is dropped —
// presumably because streams there are not `Send`, mirroring the
// `async_trait(?Send)` split used for `Net` on that target.
#[cfg(not(target_arch = "wasm32"))]
type RawByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, NetError>> + Send>>;
#[cfg(target_arch = "wasm32")]
type RawByteStream = Pin<Box<dyn Stream<Item = Result<Bytes, NetError>>>>;
/// A stream of byte chunks (presumably a response body) paired with a set of headers.
///
/// Each item yielded is either a chunk of [`Bytes`] or a [`NetError`].
pub struct ByteStream {
    /// Headers carried alongside the stream; empty when constructed via
    /// [`ByteStream::without_headers`].
    pub headers: Headers,
    // Boxed, type-erased chunk stream (`Send` only on non-wasm targets).
    inner: RawByteStream,
}
impl ByteStream {
    /// Builds a `ByteStream` from a set of `headers` and the chunk stream `inner`.
    #[must_use]
    pub fn new(headers: Headers, inner: RawByteStream) -> Self {
        Self { inner, headers }
    }

    /// Consumes the wrapper and hands back only the underlying chunk stream.
    ///
    /// The headers are discarded; read [`ByteStream::headers`] beforehand if needed.
    #[must_use]
    pub fn into_inner(self) -> RawByteStream {
        let Self { inner, .. } = self;
        inner
    }

    /// Wraps `inner` with an empty header set.
    #[must_use]
    pub fn without_headers(inner: RawByteStream) -> Self {
        Self::new(Headers::new(), inner)
    }
}
impl Stream for ByteStream {
    type Item = Result<Bytes, NetError>;

    /// Polls the wrapped chunk stream; the headers play no part in polling.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` needs `Self: Unpin`. That holds because the inner stream is already
        // pinned on the heap (`Pin<Box<_>>` is itself `Unpin`).
        self.get_mut().inner.as_mut().poll_next(cx)
    }

    /// Forwards the inner stream's bounds instead of the trait default `(0, None)`,
    /// so wrapping a stream does not hide its size information from consumers.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Asynchronous network client abstraction.
///
/// On non-wasm targets the returned futures are `Send` (plain `async_trait`), and a mock
/// is generated through `kithara::mock(api = NetMock)`. On `wasm32` the futures are
/// `?Send` and no mock is generated.
///
/// Every method accepts an optional [`Headers`] set (presumably extra request headers —
/// confirm against implementors) and reports failures as [`NetError`].
#[cfg_attr(not(target_arch = "wasm32"), kithara::mock(api = NetMock))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
pub trait Net: MaybeSend + MaybeSync {
    /// Fetches `url` and returns its whole body as a single [`Bytes`] buffer.
    async fn get_bytes(&self, url: Url, headers: Option<Headers>) -> Result<Bytes, NetError>;
    /// Fetches the part of `url` selected by `range` as a [`ByteStream`]
    /// (presumably via an HTTP `Range` request — confirm in implementors).
    async fn get_range(
        &self,
        url: Url,
        range: RangeSpec,
        headers: Option<Headers>,
    ) -> Result<ByteStream, NetError>;
    /// Returns only the [`Headers`] for `url`, with no body
    /// (presumably an HTTP `HEAD` request).
    async fn head(&self, url: Url, headers: Option<Headers>) -> Result<Headers, NetError>;
    /// Returns the body of `url` as a [`ByteStream`] instead of buffering it
    /// the way [`Net::get_bytes`] does.
    async fn stream(&self, url: Url, headers: Option<Headers>) -> Result<ByteStream, NetError>;
}
/// Convenience adapters that wrap a [`Net`] implementation in middleware layers.
///
/// Blanket-implemented for every sized `Net` type, so the adapters are available on any
/// concrete client.
pub trait NetExt: Net + Sized {
    /// Wraps `self` in a [`RetryNet`] driven by a [`DefaultRetryPolicy`] built from `policy`.
    ///
    /// `cancel` is handed to the retry layer. Presumably, cancelling it stops further
    /// retry attempts — confirm in `RetryNet`.
    #[must_use = "`with_retry` consumes the client; use the returned wrapper in its place"]
    fn with_retry(
        self,
        policy: RetryPolicy,
        cancel: CancellationToken,
    ) -> RetryNet<Self, DefaultRetryPolicy> {
        RetryNet::new(self, DefaultRetryPolicy::new(policy), cancel)
    }

    /// Wraps `self` in a [`TimeoutNet`] configured with `timeout`
    /// (presumably applied per request — confirm in `TimeoutNet`).
    #[must_use = "`with_timeout` consumes the client; use the returned wrapper in its place"]
    fn with_timeout(self, timeout: Duration) -> TimeoutNet<Self> {
        TimeoutNet::new(self, timeout)
    }
}

impl<T: Net> NetExt for T {}