bililive_core/stream/
heartbeat.rs1use std::marker::PhantomData;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::Waker;
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use futures::ready;
9use futures::{Sink, Stream};
10use log::debug;
11
12use crate::errors::StreamError;
13use crate::packet::{Operation, Packet, Protocol};
14
15use super::waker::WakerProxy;
16
17pub struct HeartbeatStream<T, E> {
24 stream: T,
26 tx_waker: Arc<WakerProxy>,
28 last_hb: Option<Instant>,
30 __marker: PhantomData<E>,
31}
32
33impl<T: Unpin, E> Unpin for HeartbeatStream<T, E> {}
34
35impl<T, E> HeartbeatStream<T, E> {
36 pub fn new(stream: T) -> Self {
38 Self {
39 stream,
40 tx_waker: Arc::new(Default::default()),
41 last_hb: None,
42 __marker: PhantomData,
43 }
44 }
45
46 fn with_context<F, U>(&mut self, f: F) -> U
47 where
48 F: FnOnce(&mut Context<'_>, &mut T) -> U,
49 {
50 let waker = Waker::from(self.tx_waker.clone());
51 let mut cx = Context::from_waker(&waker);
52
53 f(&mut cx, &mut self.stream)
54 }
55}
56
57impl<T, E> Stream for HeartbeatStream<T, E>
58where
59 T: Stream<Item = Result<Packet, StreamError<E>>> + Sink<Packet, Error = StreamError<E>> + Unpin,
60 E: std::error::Error,
61{
62 type Item = Result<Packet, StreamError<E>>;
63
64 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
65 self.tx_waker.rx(cx.waker());
67
68 ready!(self.with_context(|cx, s| Pin::new(s).poll_ready(cx)))?;
70
71 let now = Instant::now();
73 let need_hb = self
74 .last_hb
75 .map_or(true, |last_hb| now - last_hb >= Duration::from_secs(30));
76
77 if need_hb {
78 debug!("sending heartbeat");
80 self.as_mut()
81 .start_send(Packet::new(Operation::HeartBeat, Protocol::Json, vec![]))?;
82
83 self.last_hb = Some(now);
87
88 #[cfg(feature = "tokio")]
91 {
92 let waker = cx.waker().clone();
93 tokio1::spawn(async {
94 tokio1::time::sleep(Duration::from_secs(30)).await;
95 waker.wake();
96 });
97 }
98 #[cfg(feature = "async-std")]
99 {
100 let waker = cx.waker().clone();
101 async_std1::task::spawn(async {
102 async_std1::task::sleep(Duration::from_secs(30)).await;
103 waker.wake();
104 });
105 }
106
107 ready!(self.with_context(|cx, s| Pin::new(s).poll_flush(cx)))?;
109 }
110
111 Pin::new(&mut self.stream).poll_next(cx)
112 }
113}
114
115impl<T, E> Sink<Packet> for HeartbeatStream<T, E>
116where
117 T: Sink<Packet, Error = StreamError<E>> + Unpin,
118 E: std::error::Error,
119{
120 type Error = StreamError<E>;
121
122 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
123 self.tx_waker.tx(cx.waker());
125
126 self.with_context(|cx, s| Pin::new(s).poll_ready(cx))
128 }
129
130 fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
131 Pin::new(&mut self.stream).start_send(item)
132 }
133
134 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
135 self.tx_waker.tx(cx.waker());
137
138 self.with_context(|cx, s| Pin::new(s).poll_flush(cx))
140 }
141
142 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143 self.tx_waker.tx(cx.waker());
145
146 self.with_context(|cx, s| Pin::new(s).poll_close(cx))
148 }
149}