use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use std::time::{Duration, Instant};
use async_tungstenite::tungstenite::{error::Error as WsError, Message};
use futures::{ready, Sink, Stream};
use log::{debug, warn};
use crate::errors::{BililiveError, IncompleteResult};
use crate::packet::{Operation, Packet, Protocol};
use self::waker::WakerProxy;
pub(crate) mod retry;
mod utils;
mod waker;
#[cfg(test)]
mod tests;
/// Result type of the items yielded by `BililiveStream`'s `Stream` implementation.
type StreamResult<T> = std::result::Result<T, BililiveError>;
/// Packet-level wrapper around a raw message stream `T`.
///
/// Incoming binary messages are buffered and decoded into `Packet`s. Outgoing `Packet`s are
/// encoded and sent as binary messages. While it is polled as a `Stream`, a heartbeat packet is
/// sent whenever 30 seconds have passed since the previous one.
pub struct BililiveStream<T> {
/// The underlying raw stream, used both for reading messages and for writing them.
stream: T,
/// Shared proxy waker. `poll_next` registers its waker through `rx` and the `Sink` methods
/// register theirs through `tx`. A `Waker` built from this proxy is what the underlying
/// stream is polled with for sink operations.
tx_waker: Arc<WakerProxy>,
/// When the last heartbeat was queued; `None` until the first one is sent.
last_hb: Option<Instant>,
/// Bytes received but not yet consumed by `Packet::parse` (may hold a partial packet).
read_buffer: Vec<u8>,
}
impl<T> BililiveStream<T> {
pub fn from_raw_stream(stream: T) -> Self {
Self {
stream,
tx_waker: Arc::new(Default::default()),
last_hb: None,
read_buffer: vec![],
}
}
fn with_context<F, U>(&mut self, f: F) -> U
where
F: FnOnce(&mut Context<'_>, &mut T) -> U,
{
let waker = Waker::from(self.tx_waker.clone());
let mut cx = Context::from_waker(&waker);
f(&mut cx, &mut self.stream)
}
}
/// Minimum time between two heartbeat packets. When the `tokio` or `async-std` feature is
/// enabled, a timer of this length is also spawned after each heartbeat to wake the reader.
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);

/// Decodes at most one packet from the front of `buffer`.
///
/// Returns:
/// * `None` when the buffered bytes do not form a complete packet yet. They are left in place.
/// * `Some(Ok(packet))` when a packet was decoded. Its bytes are removed from `buffer`.
/// * `Some(Err(e))` when the bytes could not be parsed. The whole buffer is discarded so the
///   error is reported once and decoding restarts with the next message. Without this, every
///   later message would be appended behind the same unparsable bytes.
fn parse_buffered_packet(buffer: &mut Vec<u8>) -> Option<StreamResult<Packet>> {
    match Packet::parse(&*buffer) {
        IncompleteResult::Ok((remaining, pack)) => {
            debug!("packet parsed, {} bytes remaining", remaining.len());
            let consume_len = buffer.len() - remaining.len();
            buffer.drain(..consume_len);
            Some(Ok(pack))
        }
        IncompleteResult::Incomplete(needed) => {
            debug!("incomplete packet, {:?} needed", needed);
            None
        }
        IncompleteResult::Err(e) => {
            warn!("error occurred when parsing incoming packet");
            buffer.clear();
            Some(Err(e))
        }
    }
}

impl<T> Stream for BililiveStream<T>
where
    T: Stream<Item = Result<Message, WsError>> + Sink<Message, Error = WsError> + Unpin,
{
    type Item = StreamResult<Packet>;

    /// Sends a heartbeat if one is due, then yields the next decoded packet.
    ///
    /// Packets already sitting in the read buffer are returned before any new message is
    /// awaited. This covers several packets delivered in a single message. Non-binary messages
    /// are skipped.
    ///
    /// Stream end: a transport error (logged) or the end of the underlying stream ends this
    /// stream. A packet parse error is yielded as `Some(Err(_))` and the stream stays usable.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // NOTE(review): presumably lets wakeups of the proxy waker (used for sink operations)
        // reach this reader as well; see `waker::WakerProxy`.
        self.tx_waker.rx(cx.waker());
        // The sink must be able to accept a message before a heartbeat can be queued.
        ready!(self.with_context(|cx, s| Pin::new(s).poll_ready(cx)))?;

        let now = Instant::now();
        let heartbeat_due = self
            .last_hb
            .map_or(true, |last_hb| now.duration_since(last_hb) >= HEARTBEAT_INTERVAL);
        if heartbeat_due {
            debug!("sending heartbeat");
            self.as_mut()
                .start_send(Packet::new(Operation::HeartBeat, Protocol::Json, vec![]))?;
            self.last_hb = Some(now);
            // Wake this task again once the next heartbeat is due, even if no message
            // arrives in the meantime.
            #[cfg(feature = "tokio")]
            {
                let waker = cx.waker().clone();
                tokio::spawn(async {
                    tokio::time::sleep(HEARTBEAT_INTERVAL).await;
                    waker.wake();
                });
            }
            #[cfg(feature = "async-std")]
            {
                let waker = cx.waker().clone();
                async_std::task::spawn(async {
                    async_std::task::sleep(HEARTBEAT_INTERVAL).await;
                    waker.wake();
                });
            }
            ready!(self.with_context(|cx, s| Pin::new(s).poll_flush(cx)))?;
        }

        loop {
            // Hand out packets left over from earlier messages before waiting for more data.
            if !self.read_buffer.is_empty() {
                if let Some(result) = parse_buffered_packet(&mut self.read_buffer) {
                    return Poll::Ready(Some(result));
                }
            }

            match ready!(Pin::new(&mut self.stream).poll_next(cx)) {
                Some(Ok(msg)) if msg.is_binary() => {
                    self.read_buffer.extend(msg.into_data());
                }
                Some(Ok(_)) => {
                    debug!("not a binary message, dropping");
                }
                Some(Err(e)) => {
                    warn!("error occurred when receiving message: {:?}", e);
                    return Poll::Ready(None);
                }
                None => return Poll::Ready(None),
            }
        }
    }
}
/// Sending half: every `Packet` is encoded and written as a single binary message.
///
/// Each `poll_*` method registers the caller's waker with the proxy (`tx`). It then polls the
/// underlying sink through `with_context`, i.e. with the proxy waker, exactly as `poll_next`
/// does for its heartbeat. Transport errors are converted into `BililiveError`.
impl<T> Sink<Packet> for BililiveStream<T>
where
    T: Sink<Message, Error = WsError> + Unpin,
{
    type Error = BililiveError;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx_waker.tx(cx.waker());
        self.with_context(|cx, s| Pin::new(s).poll_ready(cx))
            .map_err(Into::into)
    }

    fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
        Pin::new(&mut self.stream)
            .start_send(Message::binary(item.encode()))
            .map_err(Into::into)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx_waker.tx(cx.waker());
        self.with_context(|cx, s| Pin::new(s).poll_flush(cx))
            .map_err(Into::into)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.tx_waker.tx(cx.waker());
        self.with_context(|cx, s| Pin::new(s).poll_close(cx))
            .map_err(Into::into)
    }
}