1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
//! JSON-RPC server over any stream, e.g. TCP or Unix sockets.
//!
//! Use `tokio_util::codec` to convert an `AsyncRead`/`AsyncWrite` pair into a
//! `Stream` and a `Sink`. Use `LinesCodec` or define your own codec.

use std::{sync::atomic::AtomicU64, time::Duration};

use futures_core::Stream;
use futures_util::{Sink, SinkExt, StreamExt};
use jsonrpc_core::{MetaIoHandler, Metadata};
use tokio::{sync::mpsc::channel, time::Instant};

use crate::pub_sub::Session;

/// Configuration for stream-based JSON-RPC servers such as [`serve_stream_sink`].
///
/// Start from [`StreamServerConfig::default`] and adjust it with the `with_*`
/// builder methods.
#[derive(Clone, Debug)]
pub struct StreamServerConfig {
    /// Capacity of the channel whose messages are forwarded to the sink.
    pub(crate) channel_size: usize,
    /// Maximum number of requests handled concurrently.
    pub(crate) pipeline_size: usize,
    /// Whether periodic pings and the idle timeout are enabled.
    pub(crate) keep_alive: bool,
    /// Idle time after which the connection is closed (keep-alive only).
    pub(crate) keep_alive_duration: Duration,
    /// Period between outgoing ping messages (keep-alive only).
    pub(crate) ping_interval: Duration,
}

impl Default for StreamServerConfig {
    fn default() -> Self {
        Self {
            channel_size: 8,
            pipeline_size: 1,
            keep_alive: false,
            keep_alive_duration: Duration::from_secs(60),
            ping_interval: Duration::from_secs(19),
        }
    }
}

impl StreamServerConfig {
    /// Set the capacity of the channel used to send messages through the
    /// session (these are forwarded to the sink).
    ///
    /// Default is 8.
    ///
    /// # Panics
    ///
    /// If `channel_size` is 0.
    #[must_use]
    pub fn with_channel_size(mut self, channel_size: usize) -> Self {
        assert!(channel_size > 0);
        self.channel_size = channel_size;
        self
    }

    /// Set maximum request pipelining.
    ///
    /// Up to `pipeline_size` number of requests will be handled concurrently.
    ///
    /// Default is 1, i.e. no pipelining.
    ///
    /// # Panics
    ///
    /// If `pipeline_size` is 0.
    #[must_use]
    pub fn with_pipeline_size(mut self, pipeline_size: usize) -> Self {
        assert!(pipeline_size > 0);
        self.pipeline_size = pipeline_size;
        self
    }

    /// Set whether keep alive (periodic pings plus an idle timeout) is enabled.
    ///
    /// Default is false.
    #[must_use]
    pub fn with_keep_alive(mut self, keep_alive: bool) -> Self {
        self.keep_alive = keep_alive;
        self
    }

    /// Wait for `keep_alive_duration` after the last message is received, then
    /// close the connection.
    ///
    /// Only takes effect when keep alive is enabled. Default is 60 seconds.
    #[must_use]
    pub fn with_keep_alive_duration(mut self, keep_alive_duration: Duration) -> Self {
        self.keep_alive_duration = keep_alive_duration;
        self
    }

    /// Set the interval between outgoing ping messages.
    ///
    /// Pings are only sent when keep alive is enabled. Default is 19 seconds.
    ///
    /// # Panics
    ///
    /// If `ping_interval` is zero. The server always builds a
    /// `tokio::time::interval` from this value, and that panics on a zero
    /// period, so the error is reported here at configuration time instead.
    #[must_use]
    pub fn with_ping_interval(mut self, ping_interval: Duration) -> Self {
        assert!(!ping_interval.is_zero(), "ping_interval must be non-zero");
        self.ping_interval = ping_interval;
        self
    }
}

/// A message exchanged over the stream/sink pair served by [`serve_stream_sink`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMsg {
    /// A text message. Incoming text is handled as a JSON-RPC request;
    /// outgoing text carries responses and session messages.
    Str(String),
    /// A ping. Sent periodically by the server when keep alive is enabled.
    /// Incoming pings are not passed to the RPC handler.
    Ping,
    /// A pong. Incoming pongs are not passed to the RPC handler.
    Pong,
}

/// Serve JSON-RPC requests over a bidirectional stream (Stream + Sink).
///
/// Each incoming [`StreamMsg::Str`] is passed to `rpc.handle_request` together
/// with metadata converted from this connection's [`Session`]. The response,
/// if any, is written back to `sink` as a [`StreamMsg::Str`]. Incoming
/// [`StreamMsg::Ping`] / [`StreamMsg::Pong`] messages are not passed to the
/// handler and produce no reply.
///
/// Up to `config.pipeline_size` requests are handled concurrently. Responses
/// are written in completion order, which may differ from request order when
/// `pipeline_size > 1`.
///
/// Every call creates a new [`Session`] with a process-wide unique id. Messages
/// sent on the session's channel (`raw_tx`, capacity `config.channel_size`)
/// are forwarded to `sink` as [`StreamMsg::Str`].
///
/// # Keepalive
///
/// Keepalive is only active when `config.keep_alive` is `true`:
///
/// - A [`StreamMsg::Ping`] is sent every `config.ping_interval`. The first ping
///   goes out one full interval after the call starts. Missed ticks are
///   delayed rather than sent in a burst.
/// - If `config.keep_alive_duration` elapses without any incoming message
///   finishing processing, the connection is considered dead and this
///   function returns `Ok(())`. Any kind of message counts, including pings
///   and pongs. The timer starts when the function is called.
///
/// NOTE(review): the idle timer is reset when handling of a message completes
/// (after its response, if any, is sent), not when the message is received. A
/// request that runs longer than `keep_alive_duration` may therefore let the
/// timer expire. Confirm this is intended.
///
/// # Errors
///
/// Returns `Err` if `stream` yields an error or if sending to `sink` fails.
/// Returns `Ok(())` when `stream` ends or the keepalive timeout expires.
///
/// NOTE(review): `sink` is not closed before returning. Presumably the caller
/// drops or closes it; confirm.
///
/// # Panics
///
/// Panics if `config.ping_interval` is zero, because `tokio::time::interval`
/// rejects a zero period. The interval is created even when keep alive is
/// disabled.
pub async fn serve_stream_sink<E, T: Metadata + From<Session>>(
    rpc: &MetaIoHandler<T>,
    mut sink: impl Sink<StreamMsg, Error = E> + Unpin,
    stream: impl Stream<Item = Result<StreamMsg, E>> + Unpin,
    config: StreamServerConfig,
) -> Result<(), E> {
    // Shared across all calls so each connection gets a distinct session id.
    static SESSION_ID: AtomicU64 = AtomicU64::new(0);

    let (tx, mut rx) = channel(config.channel_size);
    let session = Session {
        id: SESSION_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
        raw_tx: tx,
    };

    // Idle timer. Only polled in the select below when keep_alive is enabled.
    let dead_timer = tokio::time::sleep(config.keep_alive_duration);
    tokio::pin!(dead_timer);
    let mut ping_interval = tokio::time::interval(config.ping_interval);
    // An interval's first tick completes immediately. Reset it so the first
    // ping is sent one full period from now.
    ping_interval.reset();
    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

    // Map each incoming message to a future yielding Ok(Some(response)),
    // Ok(None) (no reply), or the stream's error. At most pipeline_size of
    // these futures run at once, and results come out in completion order.
    let mut result_stream = stream
        .map(|message_or_err| async {
            let msg = message_or_err?;
            let msg = match msg {
                StreamMsg::Str(msg) => msg,
                // Ping/Pong: nothing for the RPC handler, no reply.
                _ => return Ok(None),
            };
            Ok::<_, E>(rpc.handle_request(&msg, session.clone().into()).await)
        })
        .buffer_unordered(config.pipeline_size);
    loop {
        tokio::select! {
            result = result_stream.next() => {
                match result {
                    Some(result) => {
                        if let Some(result) = result? {
                            sink.send(StreamMsg::Str(result)).await?;
                        }
                        // Any processed message counts as activity.
                        if config.keep_alive {
                            dead_timer
                                .as_mut()
                                .reset(Instant::now() + config.keep_alive_duration);
                        }
                    }
                    // Incoming stream ended: stop serving.
                    _ => break,
                }
            }
            // This will never be None: `session` owns the sender (`raw_tx`)
            // and lives until this function returns.
            Some(msg) = rx.recv() => {
                sink.send(StreamMsg::Str(msg)).await?;
            }
            // Idle timeout expired: treat the connection as dead.
            _ = &mut dead_timer, if config.keep_alive => {
                break;
            }
            _ = ping_interval.tick(), if config.keep_alive => {
                sink.send(StreamMsg::Ping).await?;
            }
        }
    }
    Ok(())
}