// bililive_core/stream/heartbeat.rs

1use std::marker::PhantomData;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::Waker;
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use futures::ready;
9use futures::{Sink, Stream};
10use log::debug;
11
12use crate::errors::StreamError;
13use crate::packet::{Operation, Packet, Protocol};
14
15use super::waker::WakerProxy;
16
17/// Wrapper that implement heartbeat auto-response mechanism on a [`Packet`](crate::packet::Packet) stream.
18///
19/// Bilibili server requires that every client must respond to a ping packet in 60 seconds. If no
20/// response is sent, the connection will be closed remotely.
21///
22/// `HeartbeatStream` ensures that a pong packet is sent every 30 seconds.
23pub struct HeartbeatStream<T, E> {
24    /// underlying bilibili stream
25    stream: T,
26    /// waker proxy for tx, see WakerProxy for details
27    tx_waker: Arc<WakerProxy>,
28    /// last time when heart beat is sent
29    last_hb: Option<Instant>,
30    __marker: PhantomData<E>,
31}
32
33impl<T: Unpin, E> Unpin for HeartbeatStream<T, E> {}
34
35impl<T, E> HeartbeatStream<T, E> {
36    /// Add heartbeat response mechanism to the underlying bililive stream.
37    pub fn new(stream: T) -> Self {
38        Self {
39            stream,
40            tx_waker: Arc::new(Default::default()),
41            last_hb: None,
42            __marker: PhantomData,
43        }
44    }
45
46    fn with_context<F, U>(&mut self, f: F) -> U
47    where
48        F: FnOnce(&mut Context<'_>, &mut T) -> U,
49    {
50        let waker = Waker::from(self.tx_waker.clone());
51        let mut cx = Context::from_waker(&waker);
52
53        f(&mut cx, &mut self.stream)
54    }
55}
56
57impl<T, E> Stream for HeartbeatStream<T, E>
58where
59    T: Stream<Item = Result<Packet, StreamError<E>>> + Sink<Packet, Error = StreamError<E>> + Unpin,
60    E: std::error::Error,
61{
62    type Item = Result<Packet, StreamError<E>>;
63
64    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
65        // register current task to be waken on poll_ready
66        self.tx_waker.rx(cx.waker());
67
68        // ensure that all pending write op are completed
69        ready!(self.with_context(|cx, s| Pin::new(s).poll_ready(cx)))?;
70
71        // check whether we need to send heartbeat now.
72        let now = Instant::now();
73        let need_hb = self
74            .last_hb
75            .map_or(true, |last_hb| now - last_hb >= Duration::from_secs(30));
76
77        if need_hb {
78            // we need to send heartbeat, so push it into the sink
79            debug!("sending heartbeat");
80            self.as_mut()
81                .start_send(Packet::new(Operation::HeartBeat, Protocol::Json, vec![]))?;
82
83            // Update the time we sent the heartbeat.
84            // It must be earlier than other non-blocking op so that heartbeat
85            // won't be sent repeatedly.
86            self.last_hb = Some(now);
87
88            // Schedule current task to be waken in case there's no incoming
89            // websocket message in a long time.
90            #[cfg(feature = "tokio")]
91            {
92                let waker = cx.waker().clone();
93                tokio1::spawn(async {
94                    tokio1::time::sleep(Duration::from_secs(30)).await;
95                    waker.wake();
96                });
97            }
98            #[cfg(feature = "async-std")]
99            {
100                let waker = cx.waker().clone();
101                async_std1::task::spawn(async {
102                    async_std1::task::sleep(Duration::from_secs(30)).await;
103                    waker.wake();
104                });
105            }
106
107            // ensure that heartbeat is sent
108            ready!(self.with_context(|cx, s| Pin::new(s).poll_flush(cx)))?;
109        }
110
111        Pin::new(&mut self.stream).poll_next(cx)
112    }
113}
114
115impl<T, E> Sink<Packet> for HeartbeatStream<T, E>
116where
117    T: Sink<Packet, Error = StreamError<E>> + Unpin,
118    E: std::error::Error,
119{
120    type Error = StreamError<E>;
121
122    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
123        // wake current task and stream task
124        self.tx_waker.tx(cx.waker());
125
126        // poll the underlying websocket sink
127        self.with_context(|cx, s| Pin::new(s).poll_ready(cx))
128    }
129
130    fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
131        Pin::new(&mut self.stream).start_send(item)
132    }
133
134    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135        // wake current task and stream task
136        self.tx_waker.tx(cx.waker());
137
138        // poll the underlying websocket sink
139        self.with_context(|cx, s| Pin::new(s).poll_flush(cx))
140    }
141
142    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143        // wake current task and stream task
144        self.tx_waker.tx(cx.waker());
145
146        // poll the underlying websocket sink
147        self.with_context(|cx, s| Pin::new(s).poll_close(cx))
148    }
149}