misskey_websocket/
broker.rs

1use std::fmt::{self, Debug};
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::channel::{connect_websocket, TrySendError, WebSocketReceiver};
6use crate::error::{Error, Result};
7use crate::model::outgoing::OutgoingMessage;
8
9#[cfg(feature = "async-tungstenite09")]
10use async_tungstenite09 as async_tungstenite;
11
12#[cfg(feature = "async-std-runtime")]
13use async_std::task;
14#[cfg(feature = "async-std-runtime")]
15use async_std::task::sleep;
16use async_tungstenite::tungstenite::Error as WsError;
17use futures::stream::StreamExt;
18use log::{info, warn};
19#[cfg(feature = "tokio-runtime")]
20use tokio::task;
21#[cfg(feature = "tokio-runtime")]
22use tokio::time::sleep;
23#[cfg(feature = "tokio02-runtime")]
24use tokio02::task;
25#[cfg(feature = "tokio02-runtime")]
26use tokio02::time::delay_for as sleep;
27use url::Url;
28
29pub mod channel;
30pub mod handler;
31pub mod model;
32
33use channel::{control_channel, ControlReceiver, ControlSender};
34use handler::Handler;
35use model::SharedBrokerState;
36
37#[derive(Debug)]
38pub(crate) struct Broker {
39    broker_rx: ControlReceiver,
40    handler: Handler,
41    reconnect: ReconnectConfig,
42    url: Url,
43}
44
45/// Specifies the condition for reconnecting.
46#[derive(Clone)]
47pub struct ReconnectCondition {
48    inner: ReconnectConditionKind,
49}
50
51impl Debug for ReconnectCondition {
52    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
53        f.debug_tuple("ReconnectCondition")
54            .field(&self.inner)
55            .finish()
56    }
57}
58
59#[derive(Clone)]
60enum ReconnectConditionKind {
61    Always,
62    Never,
63    UnexpectedReset,
64    // Using `Arc` instead of `Box` not to lose `Clone` for the infrequest use of `Custom` variant.
65    Custom(Arc<dyn Fn(&Error) -> bool + Send + Sync + 'static>),
66}
67
68impl Debug for ReconnectConditionKind {
69    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70        match self {
71            ReconnectConditionKind::Always => f.debug_tuple("Always").finish(),
72            ReconnectConditionKind::Never => f.debug_tuple("Never").finish(),
73            ReconnectConditionKind::UnexpectedReset => f.debug_tuple("UnexpectedReset").finish(),
74            ReconnectConditionKind::Custom(_) => f.debug_tuple("Custom").finish(),
75        }
76    }
77}
78
79impl ReconnectCondition {
80    /// Creates a `ReconnectCondition` that reconnects regardless of the errors.
81    pub fn always() -> Self {
82        ReconnectCondition {
83            inner: ReconnectConditionKind::Always,
84        }
85    }
86
87    /// Creates a `ReconnectCondition` that does not reconnect regardless of the errors.
88    pub fn never() -> Self {
89        ReconnectCondition {
90            inner: ReconnectConditionKind::Never,
91        }
92    }
93
94    /// Creates a `ReconnectCondition` that reconnects when the connection is lost unexpectedly.
95    pub fn unexpected_reset() -> Self {
96        ReconnectCondition {
97            inner: ReconnectConditionKind::UnexpectedReset,
98        }
99    }
100
101    /// Creates a custom `ReconnectCondition` using the passed function.
102    pub fn custom<F>(f: F) -> Self
103    where
104        F: Fn(&Error) -> bool + Send + Sync + 'static,
105    {
106        ReconnectCondition {
107            inner: ReconnectConditionKind::Custom(Arc::new(f)),
108        }
109    }
110
111    fn should_reconnect(&self, err: &Error) -> bool {
112        match &self.inner {
113            ReconnectConditionKind::Always => true,
114            ReconnectConditionKind::Never => false,
115            ReconnectConditionKind::UnexpectedReset => {
116                let ws = match err {
117                    Error::WebSocket(ws) => ws,
118                    _ => return false,
119                };
120
121                use std::io::ErrorKind;
122                match ws.as_ref() {
123                    WsError::Protocol(_) => true,
124                    WsError::Io(e) => {
125                        e.kind() == ErrorKind::ConnectionReset || e.kind() == ErrorKind::BrokenPipe
126                    }
127                    _ => false,
128                }
129            }
130            ReconnectConditionKind::Custom(f) => f(err),
131        }
132    }
133}
134
135impl Default for ReconnectCondition {
136    /// [`unexpected_reset()`][`ReconnectCondition::unexpected_reset`] is used as a default.
137    fn default() -> ReconnectCondition {
138        ReconnectCondition::unexpected_reset()
139    }
140}
141
142/// Reconnection configuration.
143#[derive(Debug, Clone)]
144pub struct ReconnectConfig {
145    /// Sets an interval duration of automatic reconnection.
146    pub interval: Duration,
147    /// Specifies the condition for reconnecting.
148    pub condition: ReconnectCondition,
149    /// Specifies whether to re-send messages that may have failed to be sent when reconnecting.
150    pub retry_send: bool,
151}
152
153impl ReconnectConfig {
154    /// Creates a `ReconnectConfig` that disables reconnection.
155    pub fn none() -> ReconnectConfig {
156        ReconnectConfig::with_condition(ReconnectCondition::never())
157    }
158
159    /// Creates a `ReconnectConfig` with the given condition.
160    pub fn with_condition(condition: ReconnectCondition) -> ReconnectConfig {
161        ReconnectConfig {
162            condition,
163            ..Default::default()
164        }
165    }
166
167    /// Creates a `ReconnectConfig` with the given interval.
168    pub fn with_interval(interval: Duration) -> ReconnectConfig {
169        ReconnectConfig {
170            interval,
171            ..Default::default()
172        }
173    }
174}
175
176impl Default for ReconnectConfig {
177    /// `interval` is 5 secs and `retry_send` is `true` by default.
178    fn default() -> ReconnectConfig {
179        ReconnectConfig {
180            interval: Duration::from_secs(5),
181            condition: ReconnectCondition::default(),
182            retry_send: true,
183        }
184    }
185}
186
187impl Broker {
188    pub async fn spawn(
189        url: Url,
190        reconnect: ReconnectConfig,
191    ) -> Result<(ControlSender, SharedBrokerState)> {
192        let state = SharedBrokerState::working();
193        let shared_state = SharedBrokerState::clone(&state);
194
195        let (broker_tx, broker_rx) = control_channel(SharedBrokerState::clone(&state));
196
197        task::spawn(async move {
198            let mut broker = Broker {
199                url,
200                broker_rx,
201                reconnect,
202                handler: Handler::new(),
203            };
204
205            if let Some(err) = broker.run().await {
206                state.set_error(err).await;
207            } else {
208                state.set_exited().await;
209            }
210
211            // This ensures that broker (and communication channels on broker side)
212            // is dropped after `state` is surely set to `Dead` or `Exited`, thus asserts that the
213            // state must be set to `Dead` or `Exited` when these channels are found out to be closed.
214            std::mem::drop(broker);
215        });
216
217        Ok((broker_tx, shared_state))
218    }
219
220    async fn run(&mut self) -> Option<Error> {
221        let mut remaining_message = None;
222
223        loop {
224            let err = match self.task(remaining_message.take()).await {
225                Ok(()) => {
226                    info!("broker: exited normally");
227                    return None;
228                }
229                Err(e) => e,
230            };
231
232            info!("broker: task exited with error: {:?}", err.error);
233
234            if !self.reconnect.condition.should_reconnect(&err.error) {
235                warn!("broker: died with error");
236                return Some(err.error);
237            }
238
239            if self.reconnect.retry_send {
240                remaining_message = err.remaining_message;
241            }
242
243            info!(
244                "broker: attempt to reconnect in {:?}",
245                self.reconnect.interval
246            );
247            sleep(self.reconnect.interval).await;
248        }
249    }
250
251    async fn clean_handler(&mut self, websocket_rx: &mut WebSocketReceiver) -> Result<()> {
252        if self.handler.is_empty() {
253            return Ok(());
254        }
255
256        info!("broker: handler is not empty, enter receiving loop");
257        while !self.handler.is_empty() {
258            let msg = websocket_rx.recv().await?;
259            self.handler.handle(msg).await?;
260        }
261
262        Ok(())
263    }
264
265    async fn task(
266        &mut self,
267        remaining_message: Option<OutgoingMessage>,
268    ) -> std::result::Result<(), TaskError> {
269        use futures::future::{self, Either};
270
271        let (mut websocket_tx, mut websocket_rx) = match connect_websocket(self.url.clone()).await {
272            Ok(x) => x,
273            Err(error) => {
274                // retain `remaining_message` because we've not sent it yet
275                return Err(TaskError {
276                    remaining_message,
277                    error,
278                });
279            }
280        };
281
282        info!("broker: started");
283
284        if let Some(message) = remaining_message {
285            websocket_tx.try_send(message).await?;
286        }
287
288        for message in self.handler.restore_messages() {
289            websocket_tx.try_send(message).await?;
290        }
291
292        loop {
293            let t1 = websocket_rx.recv();
294            let t2 = self.broker_rx.next();
295
296            futures::pin_mut!(t1, t2);
297
298            match future::select(t1, t2).await {
299                Either::Left((msg, _)) => {
300                    while let Some(ctrl) = self.broker_rx.try_recv() {
301                        #[cfg(feature = "inspect-contents")]
302                        log::debug!("broker: received control {:?}", ctrl);
303
304                        if let Some(out) = self.handler.control(ctrl) {
305                            websocket_tx.try_send(out).await?
306                        }
307                    }
308
309                    self.handler.handle(msg?).await?;
310                }
311                Either::Right((Some(ctrl), _)) => {
312                    #[cfg(feature = "inspect-contents")]
313                    log::debug!("broker: received control {:?}", ctrl);
314
315                    if let Some(out) = self.handler.control(ctrl) {
316                        websocket_tx.try_send(out).await?
317                    }
318                }
319                Either::Right((None, _)) => {
320                    info!("broker: all controls terminated, exiting gracefully");
321                    return Ok(self.clean_handler(&mut websocket_rx).await?);
322                }
323            }
324        }
325    }
326}
327
328#[derive(Debug, Clone)]
329struct TaskError {
330    remaining_message: Option<OutgoingMessage>,
331    error: Error,
332}
333
334impl From<Error> for TaskError {
335    fn from(error: Error) -> TaskError {
336        TaskError {
337            remaining_message: None,
338            error,
339        }
340    }
341}
342
343impl From<TrySendError> for TaskError {
344    fn from(err: TrySendError) -> TaskError {
345        let TrySendError { message, error } = err;
346        TaskError {
347            remaining_message: Some(message),
348            error,
349        }
350    }
351}