use std::{ops::Deref, sync::atomic::AtomicU64, time::Duration};
use futures_core::{future::BoxFuture, Future, Stream};
use futures_util::{
future::{self, Shared},
FutureExt, Sink, SinkExt, StreamExt,
};
use jsonrpc_core::{MetaIoHandler, Metadata};
use tokio::{sync::mpsc::channel, time::Instant};
use crate::pub_sub::Session;
#[derive(Clone)]
pub struct StreamServerConfig {
pub(crate) channel_size: usize,
pub(crate) pipeline_size: usize,
pub(crate) keep_alive: bool,
pub(crate) keep_alive_duration: Duration,
pub(crate) ping_interval: Duration,
pub(crate) shutdown_signal: Shared<BoxFuture<'static, ()>>,
}
impl Default for StreamServerConfig {
fn default() -> Self {
Self {
channel_size: 8,
pipeline_size: 1,
keep_alive: false,
keep_alive_duration: Duration::from_secs(60),
ping_interval: Duration::from_secs(19),
shutdown_signal: future::pending().boxed().shared(),
}
}
}
impl StreamServerConfig {
pub fn with_channel_size(mut self, channel_size: usize) -> Self {
assert!(channel_size > 0);
self.channel_size = channel_size;
self
}
pub fn with_pipeline_size(mut self, pipeline_size: usize) -> Self {
assert!(pipeline_size > 0);
self.pipeline_size = pipeline_size;
self
}
pub fn with_keep_alive(mut self, keep_alive: bool) -> Self {
self.keep_alive = keep_alive;
self
}
pub fn with_keep_alive_duration(mut self, keep_alive_duration: Duration) -> Self {
self.keep_alive_duration = keep_alive_duration;
self
}
pub fn with_ping_interval(mut self, ping_interval: Duration) -> Self {
self.ping_interval = ping_interval;
self
}
pub fn with_shutdown<S>(mut self, shutdown: S) -> StreamServerConfig
where
S: Future<Output = ()> + Send + 'static,
{
self.shutdown_signal = shutdown.boxed().shared();
self
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum StreamMsg<S> {
Str(S),
Ping,
Pong,
}
pub async fn serve_stream_sink<E, T, S>(
rpc: &MetaIoHandler<T>,
mut sink: impl Sink<StreamMsg<S>, Error = E> + Unpin,
stream: impl Stream<Item = Result<StreamMsg<S>, E>> + Unpin,
config: StreamServerConfig,
) -> Result<(), E>
where
T: Metadata + From<Session>,
S: From<String> + Deref<Target = str>,
{
static SESSION_ID: AtomicU64 = AtomicU64::new(0);
let (tx, mut rx) = channel(config.channel_size);
let session = Session {
id: SESSION_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
raw_tx: tx,
};
let dead_timer = tokio::time::sleep(config.keep_alive_duration);
tokio::pin!(dead_timer);
let mut ping_interval = tokio::time::interval(config.ping_interval);
ping_interval.reset();
ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
let mut result_stream = stream
.map(|message_or_err| async {
let msg = message_or_err?;
match msg {
StreamMsg::Str(msg) => Ok(rpc
.handle_request(&msg, session.clone().into())
.await
.map(|res| StreamMsg::Str(res.into()))),
StreamMsg::Ping => Ok(Some(StreamMsg::Pong)),
StreamMsg::Pong => Ok(None),
}
})
.buffer_unordered(config.pipeline_size);
let mut shutdown = config.shutdown_signal;
loop {
tokio::select! {
biased;
result = result_stream.next() => {
match result {
Some(result) => {
if let Some(s) = result? {
sink.send(s).await?;
}
if config.keep_alive {
dead_timer
.as_mut()
.reset(Instant::now() + config.keep_alive_duration);
}
}
None => {
break;
}
}
}
Some(msg) = rx.recv() => {
sink.send(StreamMsg::Str(msg.into())).await?;
}
_ = ping_interval.tick(), if config.keep_alive => {
sink.send(StreamMsg::Ping).await?;
}
_ = &mut dead_timer, if config.keep_alive => {
break;
}
_ = &mut shutdown => {
break;
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use futures_util::stream;
#[tokio::test]
async fn test_ping_pong() {
let rpc = MetaIoHandler::<Session>::default();
let stream = stream::iter([Ok(StreamMsg::Ping)]);
let mut sink: Vec<StreamMsg<String>> = Vec::new();
let result = serve_stream_sink(&rpc, &mut sink, stream, Default::default()).await;
assert!(result.is_ok());
assert_eq!(sink, [StreamMsg::Pong]);
}
#[tokio::test]
async fn test_subscription() {
let mut rpc = MetaIoHandler::default();
rpc.add_method_with_meta("subscribe", |_params, session: Session| async move {
session
.raw_tx
.send("subscription_data".to_string())
.await
.unwrap();
Ok(serde_json::Value::String("ok".to_string()))
});
let stream = async_stream::stream! {
yield Ok(StreamMsg::Str(
r#"{"jsonrpc":"2.0","method":"subscribe","params":[],"id":1}"#.to_string(),
));
tokio::time::sleep(Duration::from_secs(1)).await;
};
tokio::pin!(stream);
let mut sink: Vec<StreamMsg<String>> = Vec::new();
let result = serve_stream_sink(&rpc, &mut sink, stream, Default::default()).await;
assert!(result.is_ok());
assert_eq!(sink.len(), 2);
assert_eq!(
serde_json::from_str::<serde_json::Value>(match &sink[0] {
StreamMsg::Str(s) => s,
_ => panic!("Expected StreamMsg::Str, got: {:?}", sink[0]),
})
.unwrap(),
serde_json::json!({ "jsonrpc": "2.0", "result": "ok", "id": 1 })
);
assert_eq!(sink[1], StreamMsg::Str("subscription_data".to_string()));
}
}