use std::{collections::HashMap, sync::Arc};
use cs_trace::{Tracer, child};
use cs_utils::{random_number, futures::wait};
use jsonrpc_core::{Result, IoDelegate, BoxFuture};
use jsonrpc_derive::rpc;
use tokio::{io::duplex, sync::{Mutex, mpsc}};
pub use crate::errors::RpcErrorCodes;
mod rpc_channel_record;
pub use rpc_channel_record::RpcChannelRecord;
pub use gen_client::Client as RpcChannelsServiceClient;
use super::rpc_channel::RemoteRpcChannelEvent;
/// JSON-RPC interface for managing data channels on the remote side.
///
/// The `#[rpc]` attribute (jsonrpc_derive) also generates the `gen_client::Client`
/// type that this file re-exports as `RpcChannelsServiceClient`.
#[rpc]
pub trait RpcChannelsServiceDefinition {
    /// `createChannel`: creates a channel carrying `label` and resolves to its id.
    #[rpc(name = "createChannel")]
    fn create_channel(&self, label: String) -> BoxFuture<Result<u16>>;
    /// `removeChannel`: removes the channel registered under `channel_id`.
    #[rpc(name = "removeChannel")]
    fn remove_channel(&self, channel_id: u16) -> BoxFuture<Result<()>>;
    /// `sendChannelData`: writes `data` into the channel registered under
    /// `channel_id` and resolves to a byte count (callers treat it as the number
    /// of bytes accepted, which may be less than `data.len()`).
    #[rpc(name = "sendChannelData")]
    fn send_channel_data(&self, channel_id: u16, data: Vec<u8>) -> BoxFuture<Result<usize>>;
}
/// Server-side state backing the channels JSON-RPC service.
pub struct RpcChannelsService {
    /// Tracer scoped to this service (a `channels-service` child of the tracer passed to `new`).
    trace: Box<dyn Tracer>,
    /// Shared registry of open channel records, keyed by channel id.
    channels: Arc<Mutex<HashMap<u16, RpcChannelRecord>>>,
    /// Sink that receives a `RemoteRpcChannelEvent` for every channel created via `createChannel`.
    on_data_channel_sink: Arc<Mutex<mpsc::Sender<RemoteRpcChannelEvent>>>,
}
impl RpcChannelsService {
pub fn new(
trace: &Box<dyn Tracer>,
channels: Arc<Mutex<HashMap<u16, RpcChannelRecord>>>,
on_data_channel_sink: mpsc::Sender<RemoteRpcChannelEvent>,
) -> Self {
let on_data_channel_sink = Arc::new(Mutex::new(on_data_channel_sink));
let trace = child!(trace, "channels-service");
return RpcChannelsService {
trace,
channels,
on_data_channel_sink,
}
}
pub fn as_delegate(self) -> IoDelegate<Self, ()> {
return self.to_delegate();
}
}
pub async fn setup_channel(
trace: &Box<dyn Tracer>,
channel_id: Option<u16>,
label: String,
channels: Arc<Mutex<HashMap<u16, RpcChannelRecord>>>,
) -> Result<RemoteRpcChannelEvent> {
let (source, sink) = duplex(4096);
let channel_id = {
let channel_record = RpcChannelRecord::new(Box::pin(source));
let channel_id = match channel_id {
Some(id) => id,
None => {
loop {
let channel_id = random_number(0..=u16::MAX);
if !channels.lock().await.contains_key(&channel_id) {
break channel_id;
}
wait(5).await;
}
},
};
let existing_channel = channels
.lock().await
.insert(channel_id, channel_record);
trace.debug(
&format!("channel record created: \"{}\"", &channel_id),
);
trace.debug(
&format!("contains channel: \"{:?}\"", &channels.lock().await.contains_key(&channel_id)),
);
assert!(
existing_channel.is_none(),
"Overrident existent channel.",
);
channel_id
};
{
let rpc_channel = RemoteRpcChannelEvent::new(
channel_id,
label,
sink,
);
return Ok(rpc_channel);
}
}
impl RpcChannelsServiceDefinition for RpcChannelsService {
    /// Handles `createChannel`.
    ///
    /// Registers a new channel under a random free id and forwards its
    /// `RemoteRpcChannelEvent` to `on_data_channel_sink`. Resolves to the new id.
    ///
    /// # Errors
    ///
    /// If the event cannot be delivered (the sink's receiver is gone), the record
    /// that was just registered is removed again and
    /// `RpcErrorCodes::ChannelCreationFailed` is returned.
    fn create_channel(&self, label: String) -> BoxFuture<Result<u16>> {
        let channels = Arc::clone(&self.channels);
        let on_data_channel_sink = Arc::clone(&self.on_data_channel_sink);
        let trace = &self.trace;
        let create_channel = child!(trace, "create_channel");

        return Box::pin(async move {
            let rpc_channel = setup_channel(&create_channel, None, label, Arc::clone(&channels)).await?;
            create_channel.debug(
                &format!("created rpc channel receiver \"{}\"", rpc_channel.label()),
            );
            let channel_id = rpc_channel.id();

            // Clone the sender so the mutex is not held while `send` waits for capacity.
            let sender = on_data_channel_sink.lock().await.clone();
            if sender.send(rpc_channel).await.is_err() {
                // Nobody will ever receive this channel; don't leave its record behind.
                channels.lock().await.remove(&channel_id);
                return Err(
                    RpcErrorCodes::ChannelCreationFailed.into(),
                );
            }

            return Ok(channel_id);
        });
    }

    /// Handles `removeChannel`: drops the record registered under `channel_id`.
    ///
    /// # Errors
    ///
    /// `RpcErrorCodes::ChannelRemovalFailed` when no such channel is registered.
    fn remove_channel(&self, channel_id: u16) -> BoxFuture<Result<()>> {
        let channels = Arc::clone(&self.channels);
        return Box::pin(async move {
            let removed = channels
                .lock().await
                .remove(&channel_id);
            return match removed {
                Some(_) => Ok(()),
                None => Err(
                    RpcErrorCodes::ChannelRemovalFailed.into(),
                ),
            };
        });
    }

    /// Handles `sendChannelData`: writes `data` into the channel registered under
    /// `channel_id` and resolves to whatever the record's `send_data` reports.
    ///
    /// # Errors
    ///
    /// `RpcErrorCodes::NoChannelFound` when no such channel is registered, or any
    /// error returned by `send_data`.
    fn send_channel_data(&self, channel_id: u16, data: Vec<u8>) -> BoxFuture<Result<usize>> {
        let channels = Arc::clone(&self.channels);
        let trace = &self.trace;
        let send_data_trace = child!(trace, "send-data");
        return Box::pin(async move {
            // NOTE(review): this guard stays alive across `send_data(..).await`, so a
            // slow write blocks every other create/remove/send on the registry.
            let mut guard = channels.lock().await;
            let available_channels: Vec<String> = guard
                .keys()
                .map(|key| key.to_string())
                .collect();
            send_data_trace.debug(
                &format!("available channels: {:?}", &available_channels),
            );
            return match guard.get_mut(&channel_id) {
                None => Err(
                    RpcErrorCodes::NoChannelFound.into(),
                ),
                Some(channel) => channel.send_data(&data[..]).await,
            };
        });
    }
}
#[cfg(test)]
mod tests {
    use std::{pin::Pin, sync::Arc, collections::HashMap};
    use cs_trace::create_trace;
    use futures::Future;
    use tokio::sync::{mpsc::{self, Receiver}, Mutex};
    use crate::{multiplexed_connection::rpc::{RemoteRpcChannelEvent, RpcChannelsService}, utils::AsyncDuplexRpcServer};
    pub use super::RpcChannelsServiceClient;

    /// Wires an `RpcChannelsService` to a JSON-RPC client over an in-memory duplex pipe.
    ///
    /// Returns:
    /// * the typed client,
    /// * the receiver on which the service publishes each channel it creates,
    /// * a future that runs the client transport (`rpc_client`); every test spawns
    ///   it alongside its requests.
    fn setup() -> (RpcChannelsServiceClient, Receiver<RemoteRpcChannelEvent>, Pin<Box<dyn Future<Output = ()> + Send>>) {
        use cs_utils::futures::GenericCodec;
        use futures::StreamExt;
        use jsonrpc_core::IoHandler;
        use jsonrpc_core_client::transports;
        use tokio_util::codec::Framed;
        use tokio::try_join;

        let trace = create_trace!("setup");
        let (on_data_channel_sink, on_data_channel_source) = mpsc::channel(10);
        let (server_duplex, client_duplex) = tokio::io::duplex(10);

        let mut server_io = IoHandler::new();
        server_io.extend_with(
            RpcChannelsService::new(
                &trace,
                Arc::new(Mutex::new(HashMap::new())),
                on_data_channel_sink,
            ).as_delegate(),
        );
        // NOTE(review): `_server` is dropped when `setup` returns; presumably `build`
        // starts the server loop on its own — confirm.
        let _server = AsyncDuplexRpcServer::new(server_io).build(server_duplex);

        // Client side: split the framed pipe into a message sink and a message
        // stream; codec errors panic through `unwrap`.
        let framed_remote_stream = Framed::new(
            client_duplex,
            GenericCodec::<String>::new(),
        );
        let (sink, mut source) = framed_remote_stream.split();
        let stream = async_stream::stream! {
            while let Some(message) = source.next().await {
                let message = message.unwrap();
                yield message;
            }
        };
        let (rpc_client, sender) = transports::duplex(Box::pin(sink), Box::pin(stream));
        let client = RpcChannelsServiceClient::from(sender);
        let fut: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(async move {
            try_join!(
                tokio::spawn(async move {
                    rpc_client.await.unwrap();
                }),
            ).unwrap();
        });
        return (client, on_data_channel_source, fut);
    }

    mod create_channel {
        use cs_utils::random_str;
        use rstest::rstest;
        use tokio::try_join;
        pub use super::RpcChannelsServiceClient;
        use super::setup;

        /// `createChannel` must emit a channel event whose id and label match the
        /// RPC result and the requested label.
        #[rstest]
        #[case::test_number(0)]
        #[case::test_number(1)]
        #[case::test_number(2)]
        #[tokio::test]
        async fn creates_channel(
            #[case] _test_number: usize,
        ) {
            let (
                client,
                mut on_data_channel_source,
                fut,
            ) = setup();
            let ((channel_id, channel_label), rpc_channel, _) = try_join!(
                tokio::spawn(async move {
                    let channel_label = format!("rpc-channel-{}", random_str(4));
                    let channel_id = client
                        .create_channel(channel_label.clone()).await
                        .expect("Cannot create channel.");
                    (channel_id, channel_label)
                }),
                tokio::spawn(async move {
                    // `recv` resolves to `None` once every sender is dropped, so a
                    // closed sink fails the test rather than being polled forever.
                    on_data_channel_source
                        .recv().await
                        .expect("Channel sink closed before a channel was received.")
                }),
                tokio::spawn(fut),
            ).unwrap();
            assert_eq!(
                channel_id,
                rpc_channel.id(),
                "Must create channel with passed channel id.",
            );
            assert_eq!(
                &channel_label,
                rpc_channel.label(),
                "Must create channel with passed label.",
            );
        }
    }

    mod remove_channel {
        use cs_utils::random_str;
        use rstest::rstest;
        use tokio::try_join;
        use cs_utils::{futures::wait, random_number};
        pub use super::RpcChannelsServiceClient;
        use super::setup;

        /// A removed channel can neither be removed again nor receive data.
        #[rstest]
        #[case::test_number(0)]
        #[case::test_number(1)]
        #[case::test_number(2)]
        #[tokio::test]
        async fn removes_channel(
            #[case] _test_number: usize,
        ) {
            let (
                client,
                _on_data_channel_source,
                fut,
            ) = setup();
            try_join!(
                tokio::spawn(async move {
                    let channel_label = format!("rpc-channel-{}", random_str(4));
                    let channel_id = client
                        .create_channel(channel_label).await
                        .expect("Cannot create channel.");
                    wait(random_number(50..100)).await;
                    client
                        .remove_channel(channel_id).await
                        .expect("Cannot remove channel.");
                    wait(random_number(50..100)).await;
                    let result = client
                        .remove_channel(channel_id).await;
                    assert!(
                        result.is_err(),
                        "Must return an error.",
                    );
                    let result = client
                        .send_channel_data(channel_id, vec![1, 2, 3]).await;
                    assert!(
                        result.is_err(),
                        "Must return an error.",
                    );
                }),
                tokio::spawn(fut),
            ).unwrap();
        }
    }

    mod send_channel_data {
        pub use super::RpcChannelsServiceClient;
        use super::setup;
        use cs_utils::random_str;
        use tokio::try_join;
        use rstest::rstest;

        /// Data sent through `sendChannelData` in random-sized chunks must arrive
        /// intact on the receiving end of the created channel.
        #[rstest]
        #[case::size_8_32(8, 32)]
        #[case::size_128_512(128, 512)]
        #[case::size_2048_4096(2048, 4096)]
        #[case::size_4096_8192(4096, 8192)]
        #[case::size_8192_16384(8192, 16384)]
        #[tokio::test]
        async fn sends_channel_data(
            #[case] str_min_size: usize,
            #[case] str_max_size: usize,
        ) {
            use cs_utils::{random_str_rg, random_number};
            let (
                client,
                mut on_data_channel_source,
                fut,
            ) = setup();
            let test_data = vec![
                random_str_rg(str_min_size..=str_max_size),
                random_str_rg(str_min_size..=str_max_size),
                random_str_rg(str_min_size..=str_max_size),
                random_str_rg(str_min_size..=str_max_size),
                random_str_rg(str_min_size..=str_max_size),
                random_str_rg(str_min_size..=str_max_size),
                random_str_rg(str_min_size..=str_max_size),
            ].join("");
            let data_to_send = test_data.clone();
            try_join!(
                tokio::spawn(async move {
                    let channel_label = format!("rpc-channel-{}", random_str(4));
                    let channel_id = client
                        .create_channel(channel_label).await
                        .expect("Cannot create channel.");
                    let data = data_to_send.as_bytes().to_vec();
                    let mut i = 0;
                    while i < data.len() {
                        // Offer a random-sized slice starting at `i`; the service may
                        // accept fewer bytes, so advance by the count it reports.
                        let chunk_len = random_number(str_min_size..=str_max_size) / 2;
                        let chunk_end = (i + chunk_len).min(data.len());
                        let bytes_sent = client
                            .send_channel_data(channel_id, data[i..chunk_end].to_vec()).await
                            .expect("Cannot send channel data.");
                        assert!(
                            bytes_sent > 0,
                            "No bytes sent.",
                        );
                        i += bytes_sent;
                    }
                }),
                tokio::spawn(async move {
                    // `recv` resolves to `None` once every sender is dropped, so a
                    // closed sink fails the test rather than being polled forever.
                    let channel = on_data_channel_source
                        .recv().await
                        .expect("Channel sink closed before a channel was received.");
                    crate::multiplexed_connection::rpc::tests::assert_rpc_channel_receiver_receives_data(channel, test_data).await;
                }),
                tokio::spawn(fut),
            ).unwrap();
        }
    }
}