1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
use {super::*, crate::*};

use std::future::Future;
use std::sync::Arc;

use futures::prelude::*;

use tokio::prelude::*;
use tokio::prelude::{AsyncRead, AsyncWrite};

use tokio::sync::Mutex;
use tokio::time::Duration;

// How long the connection may stay quiet before we probe it: 45 seconds after
// the last receive we'll send a ping
const PING_INACTIVITY: Duration = Duration::from_secs(45);

// and then wait 10 seconds for a pong response before treating the connection
// as timed out
const PING_WINDOW: Duration = Duration::from_secs(10);

// A heap-allocated, pinned, `Send` future with a type-erased concrete type.
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn Future<Output = T> + 'a + Send>>;
// The future produced by a `Connector` factory: resolves to the connected IO
// type or to the `std::io::Error` that prevented the connection.
type ConnectFuture<IO> = BoxFuture<'static, Result<IO, std::io::Error>>;

/// A connector type that acts as a factory for connecting to Twitch
///
/// It stores a shareable, type-erased factory function. Each call of the
/// factory yields a new future that resolves to an `IO` value or to an
/// `std::io::Error`.
pub struct Connector<IO> {
    // The factory itself. It is reference counted so that cloning a
    // `Connector` shares the same function instead of copying it.
    connect: Arc<dyn Fn() -> ConnectFuture<IO> + Send + Sync + 'static>,
}

impl<IO> Connector<IO>
where
    IO: AsyncRead + AsyncWrite,
    IO: Send + Sync + 'static,
{
    /// Create a new connector with this factory function
    pub fn new<F, R>(connect_func: F) -> Self
    where
        F: Fn() -> R + Send + Sync + 'static,
        R: Future<Output = Result<IO, std::io::Error>> + Send + Sync + 'static,
    {
        Self {
            connect: Arc::new(move || Box::pin(connect_func())),
        }
    }
}

impl<IO> Clone for Connector<IO> {
    /// Produces another handle to the same factory function.
    ///
    /// Only the reference count is bumped; `IO` itself does not need to be `Clone`.
    fn clone(&self) -> Self {
        let connect = Arc::clone(&self.connect);
        Self { connect }
    }
}

impl<IO> std::fmt::Debug for Connector<IO> {
    /// Formats the connector as just its type name; the stored factory is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("Connector");
        repr.finish()
    }
}

/// Some common retry strategies.
///
/// These are used with [`Runner::run_with_retry`][retry].
///
/// You can provide your own by simply having an async function with the same
/// signature.
///
/// That is `async fn(result: Result<Status, Error>) -> Result<bool, Error>`.
///
/// Return one of:
/// * `Ok(true)` to cause it to reconnect.
/// * `Ok(false)` to make `run_with_retry` gracefully exit with `Ok(())`
/// * `Err(err)` to make `run_with_retry` return that error
///
/// [retry]: ./struct.Runner.html#method.run_with_retry
#[derive(Copy, Clone, Debug, Default)]
pub struct RetryStrategy;

impl RetryStrategy {
    /// Always reconnects, except when the run ended with `Status::Canceled`.
    ///
    /// Errors are not propagated by this strategy; they also trigger a reconnect.
    pub async fn immediately(result: Result<Status, Error>) -> Result<bool, Error> {
        match result {
            Ok(Status::Canceled) => Ok(false),
            _ => Ok(true),
        }
    }

    /// Reconnects only when the run ended with `Status::Timeout`.
    ///
    /// An `Err` is returned unchanged, and any other status yields `false`
    /// (which stops the connection loop).
    pub async fn on_timeout(result: Result<Status, Error>) -> Result<bool, Error> {
        match result? {
            Status::Timeout => Ok(true),
            _ => Ok(false),
        }
    }

    /// Reconnects only when the run ended with an error.
    pub async fn on_error(result: Result<Status, Error>) -> Result<bool, Error> {
        match result {
            Err(..) => Ok(true),
            Ok(..) => Ok(false),
        }
    }
}

/// A type that drives the __event loop__ to completion, and optionally retries on a return condition.
///
/// This type is used to 'drive' the dispatcher and internal read/write futures.
pub struct Runner {
    // messages decoded from the connection are dispatched through this; it also
    // provides the internal Ping/Pong subscriptions used by the run loop
    dispatcher: Dispatcher,
    // receiving end of the writer's channel; data taken from it is written to the IO type
    receiver: Rx,
    // cloned by the run loop to answer PINGs and to send keep-alive PINGs
    writer: Writer,
    // the run loop waits on this for the user's stop signal
    abort: abort::Abort,
    // shared with the `Control`; notified when a run of the loop ends
    ready: Arc<tokio::sync::Notify>,
}

impl Runner {
    /// Create a Runner for the provided dispatcher, using the default rate limiter.
    ///
    /// This is shorthand for [`Runner::new_with_rate_limit`](#method.new_with_rate_limit)
    /// called with `RateLimit::default()`.
    ///
    /// # Returns
    /// This returns the [`Runner`] and a [`Control`] type that'll let you interact with the Runner.
    ///
    /// [`Runner`]: ./struct.Runner.html
    /// [`Control`]: ./struct.Control.html
    pub fn new(dispatcher: Dispatcher) -> (Runner, Control) {
        let rate_limit = RateLimit::default();
        Self::new_with_rate_limit(dispatcher, rate_limit)
    }

    /// Create a Runner whose writer has no rate limiter attached.
    ///
    /// # Warning
    /// This is not advisable and goes against the 'rules' of the API.
    ///
    /// Prefer [`Runner::new`](#method.new), which uses a default rate limiter,
    /// or [`Runner::new_with_rate_limit`](#method.new_with_rate_limit) if you
    /// know your 'bot' status is higher than normal.
    ///
    /// # Returns
    /// This returns the [`Runner`] and a [`Control`] type that'll let you interact with the Runner.
    ///
    /// [`Runner`]: ./struct.Runner.html
    /// [`Control`]: ./struct.Control.html
    pub fn new_without_rate_limit(dispatcher: Dispatcher) -> (Runner, Control) {
        // everything written through the `Writer` ends up on this channel,
        // and the runner drains the receiving end
        let (sender, receiver) = mpsc::channel(64);
        let writer = Writer::new(crate::encode::AsyncMpscWriter::new(sender));

        let stop_signal = abort::Abort::default();
        let ready = Arc::new(tokio::sync::Notify::default());

        // the control half shares the writer, stop signal and ready notifier with the runner
        let control = Control {
            writer: writer.clone(),
            stop: stop_signal.clone(),
            ready: Arc::clone(&ready),
        };

        let runner = Self {
            dispatcher,
            receiver,
            writer,
            abort: stop_signal,
            ready,
        };

        (runner, control)
    }

    /// Create a new Runner with the provided dispatcher and rate limiter
    ///
    /// The rate limiter is attached to the writer, which is shared between
    /// the returned `Runner` and `Control`.
    ///
    /// # Returns
    /// This returns the [`Runner`] and a [`Control`] type that'll let you interact with the Runner.
    ///
    /// [`Runner`]: ./struct.Runner.html
    /// [`Control`]: ./struct.Control.html
    ///
    pub fn new_with_rate_limit(dispatcher: Dispatcher, rate_limit: RateLimit) -> (Runner, Control) {
        // everything written through the `Writer` ends up on this channel,
        // and the runner drains the receiving end
        let (sender, receiver) = mpsc::channel(64);

        let limiter = Arc::new(Mutex::new(rate_limit));
        let writer =
            Writer::new(crate::encode::AsyncMpscWriter::new(sender)).with_rate_limiter(limiter);

        let stop_signal = abort::Abort::default();
        let ready = Arc::new(tokio::sync::Notify::default());

        // the control half shares the writer, stop signal and ready notifier with the runner
        let control = Control {
            writer: writer.clone(),
            stop: stop_signal.clone(),
            ready: Arc::clone(&ready),
        };

        let runner = Self {
            dispatcher,
            receiver,
            writer,
            abort: stop_signal,
            ready,
        };

        (runner, control)
    }

    /// Run to completion.
    ///
    /// This takes a [`Connector`][connector] which acts a factory for producing IO types.
    ///
    /// This will only call the connector factory once. If you want to reconnect
    /// automatically, refer to [`Runner::run_with_retry`][retry]. That function takes in
    /// a retry strategy for determining how to continue on disconnection.
    ///
    /// The following happens during the operation of this future
    /// * Connects using the provided [`Connector`][connector]
    /// * Automatically replies with a `PONG` when a `PING` is received
    /// * Checks for timeouts (a `PING` is sent after a period without received
    ///   data, and a missing `PONG` reply ends the loop).
    /// * Reads from the IO type, parsing and dispatching messages
    /// * Reads from the writer and forwards it to the IO type
    /// * Listens for user cancellation from the [`Control::stop`][stop] method.
    ///
    /// # Returns a future that resolves to..
    /// * An [error] if one was encountered while in operation (connecting,
    ///   decoding a message, or writing to the IO type)
    /// * [`Ok(Status::Eof)`][eof] if it ran to completion
    /// * [`Ok(Status::Canceled)`][cancel] if the associated [`Control::stop`][stop] was called
    /// * [`Ok(Status::Timeout)`][timeout] if the connection was detected as stale
    ///
    /// [connector]: ./struct.Connector.html
    /// [error]: ./enum.Error.html
    /// [eof]: ./enum.Status.html#variant.Eof
    /// [cancel]: ./enum.Status.html#variant.Canceled
    /// [timeout]: ./enum.Status.html#variant.Timeout
    /// [stop]: ./struct.Control.html#method.stop
    /// [retry]: #method.run_with_retry
    ///
    pub async fn run_to_completion<IO>(&mut self, connector: Connector<IO>) -> Result<Status, Error>
    where
        IO: AsyncRead + AsyncWrite,
        IO: Unpin + Send + Sync + 'static,
    {
        // ask the factory for a connection; a failure to connect is reported as `Error::Io`
        let io = (connector.connect)().await.map_err(Error::Io)?;

        let mut stream = tokio::io::BufStream::new(io);
        let mut buffer = String::with_capacity(1024);

        // internal subscription used to answer server PINGs
        let mut ping = self
            .dispatcher
            .subscribe_internal::<crate::events::Ping>(true);

        // Guard that notifies both handles when it is dropped, i.e. whenever
        // this function returns (normally, through `?`, or by cancellation).
        struct Token(Arc<tokio::sync::Notify>, Arc<tokio::sync::Notify>);
        impl Drop for Token {
            fn drop(&mut self) {
                self.0.notify();
                self.1.notify();
            }
        }

        let restart = Arc::new(tokio::sync::Notify::default());

        // when this drops, the check_connection loop will exit
        // (and the shared `ready` handle is notified as well)
        let _token = Token(restart.clone(), self.ready.clone());

        let mut out = self.writer.clone();

        // we start a 2nd loop that runs outside of the main loop
        // this sends a ping if we've not received a message within the window defined by PING_INACTIVITY
        // and if we didn't get a PONG response within PING_WINDOW we'll consider the connection stale and exit
        let (mut check_timeout, timeout_delay, timeout_task) =
            check_connection(restart, &self.dispatcher, out.clone());

        loop {
            tokio::select! {
                // Abort notification
                _ = self.abort.wait_for() => {
                    log::debug!("received signal from user to stop");
                    let _ = self.dispatcher.clear_subscriptions_all();
                    break Ok(Status::Canceled)
                }

                // Auto-pong: reply to the server's PING with its own token
                Some(msg) = ping.next() => {
                    if out.pong(&msg.token).await.is_err() {
                        log::debug!("cannot send pong");
                        break Ok(Status::Eof);
                    }
                }

                // Read half
                // NOTE(review): an `Err` from `read_line` does not match this pattern, so the
                // read error is not propagated from here -- confirm this is intended
                Ok(n) = &mut stream.read_line(&mut buffer) => {
                    if n == 0 {
                        log::info!("read 0 bytes. this is an EOF");
                        break Ok(Status::Eof)
                    }

                    let mut visited = false;
                    for msg in decode(&buffer) {
                        let msg = msg?;
                        log::trace!(target: "twitchchat::runner::read", "< {}", msg.raw.escape_debug());
                        self.dispatcher.dispatch(&msg);
                        visited = true;
                    }

                    // if we didn't parse a message then we should signal that this was EOF
                    // twitch sometimes just stops writing to the client
                    if !visited {
                        log::warn!("twitch sent an incomplete message");
                        break Ok(Status::Eof)
                    }
                    buffer.clear();

                    // tell the keep-alive task that we've seen activity; if the
                    // task is gone this fails, which is deliberately ignored
                    let _ = check_timeout.send(()).await;
                },

                // Write half
                Some(data) = &mut self.receiver.next() => {
                    // lossy conversion: a log line must never panic on non-UTF-8 data
                    log::trace!(target: "twitchchat::runner::write", "> {}", String::from_utf8_lossy(&data).escape_debug());
                    stream.write_all(&data).await?;
                    // flush after each line -- people probably prefer messages sent early
                    stream.flush().await?
                },

                // We received a timeout
                _ = timeout_delay.notified() => {
                    log::warn!(target: "twitchchat::runner::timeout", "timeout detected, quitting loop");
                    // force the loop to exit (we could also use the 'restart' notify here)
                    drop(check_timeout);
                    // and wait for the task to join
                    timeout_task.await;
                    break Ok(Status::Timeout);
                },

                // All of the futures are dead, so the loop should end
                else => {
                    log::info!("all futures are dead. ending loop");
                    break Ok(Status::Eof)
                }
            }
        }
    }

    /// Runs to completion repeatedly, letting a retry functor decide whether to reconnect.
    ///
    /// This takes in a:
    /// * [`Connector`][connector] which acts a factory for producing IO types.
    /// * `retry_check`, a functor from `Result<Status, Error>` to a ___future___ of a `Result<bool, Error>`.
    ///
    /// You can pause in the `retry_check` to cause the next connection attempt to be delayed.
    ///
    /// `retry_check` return values:
    /// * `Ok(true)` will cause this to reconnect.
    /// * `Ok(false)` will cause this to exit with `Ok(())`
    /// * `Err(..)` will cause this to exit with `Err(err)`
    ///
    /// [connector]: ./struct.Connector.html
    pub async fn run_with_retry<IO, F, R>(
        &mut self,
        connector: Connector<IO>,
        retry_check: F,
    ) -> Result<(), Error>
    where
        IO: AsyncRead + AsyncWrite,
        IO: Unpin + Send + Sync + 'static,

        F: Fn(Result<Status, Error>) -> R,
        F: Send + Sync,
        R: Future<Output = Result<bool, Error>> + Send + Sync + 'static,
        R::Output: Send,
    {
        loop {
            // every attempt gets its own handle to the connector factory
            let outcome = self.run_to_completion(connector.clone()).await;

            // an error from the strategy ends the loop with that error
            let reconnect = retry_check(outcome).await?;
            if !reconnect {
                return Ok(());
            }

            // reset our internal subscriptions to stop the leak
            self.dispatcher.reset_internal_subscriptions();
        }
    }
}

/// Spawns the keep-alive task that detects a stale connection.
///
/// The task restarts a `PING_INACTIVITY` timer on every loop iteration. If the
/// timer fires before an activity signal arrives, a `PING` is sent through
/// `writer` and the task waits up to `PING_WINDOW` for a `PONG` event from the
/// `dispatcher`. If the `PING` cannot be sent, or no `PONG` arrives in time,
/// the returned notifier is triggered and the task ends.
///
/// The task also ends when `restart` is notified or when every branch of its
/// `select!` is disabled (e.g. the activity sender was dropped).
///
/// # Returns
/// A tuple of:
/// * a sender used to signal activity, which restarts the inactivity timer
/// * a notifier that is triggered when a timeout was detected
/// * the handle of the spawned task
fn check_connection(
    restart: Arc<tokio::sync::Notify>,
    dispatcher: &Dispatcher,
    mut writer: Writer,
) -> (
    tokio::sync::mpsc::Sender<()>,
    Arc<tokio::sync::Notify>,
    impl Future,
) {
    use tokio::sync::{mpsc, Notify};

    let mut pong = dispatcher.subscribe_internal::<crate::events::Pong>(true);
    let timeout_notify = Arc::new(Notify::new());
    let (tx, mut rx) = mpsc::channel(1);

    let timeout = timeout_notify.clone();
    let task = async move {
        loop {
            tokio::select! {
                // no activity was signalled within the window
                // (the delay is recreated on each iteration, so any signal resets it)
                _ = tokio::time::delay_for(PING_INACTIVITY) => {
                    log::debug!(target: "twitchchat::runner::timeout", "inactivity detected of {:?}, sending a ping", PING_INACTIVITY);

                    // The timestamp is only used as the ping token, so a clock set
                    // before the epoch falls back to 0 instead of panicking the task.
                    let ts = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .map(|elapsed| elapsed.as_secs())
                        .unwrap_or(0);

                    // try sending a ping
                    if writer.ping(&ts.to_string()).await.is_err() {
                        timeout.notify();
                        log::error!(target: "twitchchat::runner::timeout", "cannot send ping");
                        break;
                    }

                    // and if we didn't get a response in time
                    // (only an elapsed timer counts as a failure here)
                    if tokio::time::timeout(PING_WINDOW, pong.next())
                        .await
                        .is_err()
                    {
                        // exit
                        timeout.notify();
                        log::error!(target: "twitchchat::runner::timeout", "did not get a pong after {:?}", PING_WINDOW);
                        break;
                    }
                }

                // activity was signalled in time, do nothing
                Some(..) = rx.next() => { }

                // when the main loop drops, this is triggered
                _ = restart.notified() => { break }

                else => { break }
            }
        }
    };

    (tx, timeout_notify, tokio::task::spawn(task))
}

impl std::fmt::Debug for Runner {
    /// Formats the runner as just its type name; none of its fields are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("Runner");
        repr.finish()
    }
}