use std::{
collections::HashMap,
convert::Infallible,
env,
ffi::OsString,
fmt::{self, Debug, Display},
future::Future,
io,
path::Path,
pin::Pin,
str::FromStr,
sync::Arc,
task::{Context, Poll},
};
use futures::{Stream, StreamExt};
use ipc_channel::{
ipc::{IpcReceiver, IpcSender},
IpcError, TryRecvError,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use thiserror::Error;
use tokio::{
sync::{mpsc, watch},
time::{Duration, Instant},
};
use uuid::Uuid;
#[cfg(feature = "message-schema-validation")]
use schemars::{schema_for, JsonSchema, Schema};
mod client;
pub use client::*;
mod server;
pub use server::*;
/// Key identifying a server endpoint.
///
/// `initialize_server` hands one out and `initialize_client` consumes it to connect. Any
/// string is accepted as a key; it can be converted back into a `String` or an `OsString`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ConnectionKey(String);

impl From<String> for ConnectionKey {
    fn from(key: String) -> Self {
        ConnectionKey(key)
    }
}

impl FromStr for ConnectionKey {
    type Err = Infallible;

    /// Never fails: the input is taken verbatim as the key.
    fn from_str(key: &str) -> Result<Self, Self::Err> {
        Ok(ConnectionKey::from(String::from(key)))
    }
}

impl Display for ConnectionKey {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        f.write_fmt(format_args!("{}", self.0))
    }
}

impl From<ConnectionKey> for OsString {
    fn from(key: ConnectionKey) -> Self {
        let ConnectionKey(inner) = key;
        inner.into()
    }
}

impl From<ConnectionKey> for String {
    fn from(key: ConnectionKey) -> Self {
        let ConnectionKey(inner) = key;
        inner
    }
}
/// A reply registration handed to the receive loop.
///
/// Layout: `(request id, (reply channel, deadline))`. The receive loop forwards the matching
/// reply on the channel, or sends `IpcRpcError::ReplyTimeout` once the deadline has passed.
type PendingReplyEntry<U> = (
    Uuid,
    (mpsc::UnboundedSender<Result<InternalMessageKind<U>, IpcRpcError>>, Instant),
);
/// Envelope for everything exchanged over the IPC channel.
///
/// `uuid` correlates requests with replies: a reply is sent back under the id of the request
/// it answers, and the receive loop looks incoming ids up among its pending replies.
// `UserMessage` already requires `DeserializeOwned`; the empty bound stops serde from adding
// its own `U: Deserialize<'de>` bound on top.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(bound(deserialize = ""))]
struct InternalMessage<U: UserMessage> {
    uuid: uuid::Uuid,
    kind: InternalMessageKind<U>,
}

/// Payload of an [`InternalMessage`].
// NOTE(review): the variant order is presumably part of the serialized encoding used by the
// IPC layer, so keep it stable.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(bound(deserialize = ""))]
enum InternalMessageKind<U: UserMessage> {
    /// Carries an IPC sender to the peer (presumably the channel it should reply on).
    InitConnection(IpcSender<InternalMessage<U>>),
    /// Ends the session: the receive loop stops when it sees this.
    Hangup,
    /// An application-level message or reply.
    UserMessage(U),
    /// The sender's `U` schema serialized as JSON, to be compared with the receiver's.
    UserMessageSchema(String),
    /// Answer to `UserMessageSchema`: the schemas are equal.
    UserMessageSchemaOk,
    /// Answer to `UserMessageSchema`: mismatch; carries the responder's own schema as JSON.
    UserMessageSchemaError { other_schema: String },
}
/// Errors reported by the IPC RPC layer.
///
/// The enum is `Clone`. `io::Error` is not, so the wrapped underlying errors are held behind
/// an `Arc`.
#[derive(Clone, Debug, Error)]
pub enum IpcRpcError {
    /// An I/O operation failed.
    #[error("io error")]
    IoError(#[from] Arc<io::Error>),
    /// The underlying IPC channel reported an error.
    #[error("internal ipc channel error")]
    IpcChannelError(#[from] Arc<IpcError>),
    /// The connection could not be initialized in time.
    #[error("connection initialization timed out")]
    ConnectTimeout,
    /// A connection was made, but the initial handshake was not performed properly.
    #[error("connection established, but initial handshake was not performed properly")]
    HandshakeFailure,
    /// A client was already connected.
    #[error("client already connected")]
    ClientAlreadyConnected,
    /// The peer disconnected.
    #[error("peer disconnected")]
    Disconnected,
    /// No reply arrived before the request's deadline.
    #[error("time out while waiting for a reply")]
    ReplyTimeout,
    /// The connection went away while a reply was still awaited.
    #[error("connection dropped pre-emptively")]
    ConnectionDropped,
}

impl From<io::Error> for IpcRpcError {
    fn from(e: io::Error) -> Self {
        // Route through the `#[from] Arc<io::Error>` conversion.
        Arc::new(e).into()
    }
}

impl From<IpcError> for IpcRpcError {
    fn from(e: IpcError) -> Self {
        // Route through the `#[from] Arc<IpcError>` conversion.
        Arc::new(e).into()
    }
}
/// How long `send` waits for a reply before failing with `IpcRpcError::ReplyTimeout`
/// (five seconds). Use `send_timeout` to pick a different limit per call.
pub const DEFAULT_REPLY_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Reply channel and deadline of a request that is still waiting for its reply.
type PendingReplySlot<U> = (
    mpsc::UnboundedSender<Result<InternalMessageKind<U>, IpcRpcError>>,
    Instant,
);

/// Runs the receive side of one connection until it ends.
///
/// Each incoming message is handled in one of these ways:
/// - If its id matches a registered pending reply, its kind is forwarded to that reply's
///   channel.
/// - A `UserMessage` is passed to `message_handler` on a spawned task. A `Some` result is sent
///   back through `response_sender` under the same id.
/// - With the `message-schema-validation` feature, a `UserMessageSchema` is answered with
///   `UserMessageSchemaOk` or `UserMessageSchemaError`.
/// - `Hangup` ends the loop. Other kinds are ignored.
///
/// Pending replies whose deadline passes are completed with `IpcRpcError::ReplyTimeout`.
///
/// The loop also ends in these cases:
/// - `pending_reply_receiver` closes;
/// - `receiver` ends or yields `IpcError::Disconnected`;
/// - more than `MAX_CONSECUTIVE_ERRORS` receive errors arrive in a row.
///
/// On exit, `status_sender` receives `DisconnectError` in the last case and
/// `DisconnectedCleanly` otherwise. Replies still pending at that point are dropped, which
/// closes their reply channels.
///
/// `is_server` only selects the log prefix.
async fn process_incoming_mail<
    Fut: Future<Output = Option<U>> + Send,
    F: Fn(U) -> Fut + Send + Sync + 'static,
    U: UserMessage,
>(
    is_server: bool,
    mut pending_reply_receiver: mpsc::UnboundedReceiver<PendingReplyEntry<U>>,
    mut receiver: IpcReceiveStream<InternalMessage<U>>,
    message_handler: F,
    response_sender: IpcSender<InternalMessage<U>>,
    status_sender: watch::Sender<ConnectionStatus>,
) {
    // Give up after more than this many receive errors in a row; any good message resets it.
    const MAX_CONSECUTIVE_ERRORS: u32 = 20;

    let mut pending_replies = HashMap::<Uuid, PendingReplySlot<U>>::new();
    let message_handler = Arc::new(message_handler);
    let log_prefix = get_log_prefix(is_server);
    log::info!("{}Processing incoming mail!", log_prefix);
    let mut consecutive_error_count: u32 = 0;
    // Earliest deadline among pending replies; `None` while nothing is pending.
    let mut pending_reply_scheduled_time = Option::<Instant>::None;
    // Published on `status_sender` once the loop exits.
    let mut final_status = ConnectionStatus::DisconnectedCleanly;
    loop {
        drain_pending_reply_registrations(
            &mut pending_reply_receiver,
            &mut pending_replies,
            &mut pending_reply_scheduled_time,
        );
        tokio::select! {
            // Yields `false` (pattern mismatch, branch disabled) while no reply is pending.
            true = async {
                if let Some(t) = pending_reply_scheduled_time {
                    tokio::time::sleep_until(t).await;
                    true
                } else {
                    false
                }
            } => {
                pending_replies.retain(|_k, v| {
                    let keep = v.1 > Instant::now();
                    if !keep {
                        let _ = v.0.send(Err(IpcRpcError::ReplyTimeout));
                    }
                    keep
                });
                pending_reply_scheduled_time = pending_replies.values().map(|i| i.1).min();
            }
            pending_reply = pending_reply_receiver.recv() => {
                match pending_reply {
                    None => break,
                    Some(entry) => register_pending_reply(
                        &mut pending_replies,
                        &mut pending_reply_scheduled_time,
                        entry,
                    ),
                }
            }
            r = receiver.next() => {
                match r {
                    None => break,
                    Some(Err(IpcError::Disconnected)) => {
                        log::info!("{}Peer disconnected.", log_prefix);
                        break;
                    }
                    Some(Err(e)) => {
                        log::error!("{}Error receiving message from peer {:?}", log_prefix, e);
                        consecutive_error_count += 1;
                        if consecutive_error_count > MAX_CONSECUTIVE_ERRORS {
                            log::error!("{}Too many consecutive errors, shutting down.", log_prefix);
                            // Report the failure instead of claiming a clean disconnect.
                            final_status = ConnectionStatus::DisconnectError(IpcRpcError::from(e));
                            break;
                        }
                    }
                    Some(Ok(message)) => {
                        consecutive_error_count = 0;
                        log::debug!("{}Got message! {:?}", log_prefix, message);
                        // The registration for this reply may still be queued. `select!` picks
                        // randomly among ready branches, so a reply can win against its own
                        // registration. Pull in everything queued so far before deciding whether
                        // this is a reply; this assumes registrations are queued before the
                        // request is sent.
                        drain_pending_reply_registrations(
                            &mut pending_reply_receiver,
                            &mut pending_replies,
                            &mut pending_reply_scheduled_time,
                        );
                        if let Some((reply_drop_box, _)) = pending_replies.remove(&message.uuid) {
                            log::debug!("{}It's a reply, forwarding!", log_prefix);
                            let _ = reply_drop_box.send(Ok(message.kind));
                        } else {
                            log::debug!("{}It's not a reply, handling!", log_prefix);
                            let message_uuid = message.uuid;
                            match message.kind {
                                InternalMessageKind::UserMessage(user_message) => {
                                    let message_handler = Arc::clone(&message_handler);
                                    let response_sender = response_sender.clone();
                                    // Run the handler off the receive loop so slow handlers do
                                    // not block reply routing.
                                    tokio::spawn(async move {
                                        if let Some(m) = message_handler(user_message).await {
                                            let r = response_sender.send(InternalMessage {
                                                uuid: message_uuid,
                                                kind: InternalMessageKind::UserMessage(m),
                                            });
                                            if let Err(e) = r {
                                                log::error!("Failed to send reply {e:?}");
                                            }
                                        }
                                    });
                                }
                                #[cfg(feature = "message-schema-validation")]
                                InternalMessageKind::UserMessageSchema(other_schema) => {
                                    let kind = schema_validation_reply::<U>(&other_schema);
                                    let r = response_sender.send(InternalMessage {
                                        uuid: message_uuid,
                                        kind,
                                    });
                                    if let Err(e) = r {
                                        log::error!("Failed to send validation response {e:#?}");
                                    }
                                }
                                InternalMessageKind::Hangup => break,
                                _ => {}
                            }
                        }
                    }
                }
            }
        }
    }
    let _ = status_sender.send(final_status);
}

/// Records one pending reply and moves the next timeout check earlier if its deadline is sooner.
fn register_pending_reply<U: UserMessage>(
    pending_replies: &mut HashMap<Uuid, PendingReplySlot<U>>,
    next_deadline: &mut Option<Instant>,
    (uuid, (sender, deadline)): PendingReplyEntry<U>,
) {
    *next_deadline = Some(next_deadline.map_or(deadline, |t| t.min(deadline)));
    pending_replies.insert(uuid, (sender, deadline));
}

/// Moves every registration already queued on `receiver` into `pending_replies` without waiting.
fn drain_pending_reply_registrations<U: UserMessage>(
    receiver: &mut mpsc::UnboundedReceiver<PendingReplyEntry<U>>,
    pending_replies: &mut HashMap<Uuid, PendingReplySlot<U>>,
    next_deadline: &mut Option<Instant>,
) {
    while let Ok(entry) = receiver.try_recv() {
        register_pending_reply(pending_replies, next_deadline, entry);
    }
}

/// Builds the answer to a peer's `UserMessageSchema` request.
///
/// Returns `UserMessageSchemaOk` when the peer's JSON schema equals ours. Otherwise, including
/// when the peer's schema cannot be parsed, returns `UserMessageSchemaError` carrying our own
/// schema as JSON.
#[cfg(feature = "message-schema-validation")]
fn schema_validation_reply<U: UserMessage>(other_schema: &str) -> InternalMessageKind<U> {
    let my_schema = schema_for!(U);
    let mismatch = || -> InternalMessageKind<U> {
        InternalMessageKind::UserMessageSchemaError {
            other_schema: serde_json::to_string(&my_schema)
                .expect("upstream guarantees this won't fail"),
        }
    };
    match serde_json::from_str::<Schema>(other_schema) {
        Ok(their_schema) if their_schema == my_schema => InternalMessageKind::UserMessageSchemaOk,
        Ok(_) => mismatch(),
        Err(_) => {
            log::error!("Failed to deserialize incoming schema properly, got {other_schema:?}");
            mismatch()
        }
    }
}
/// Builds the prefix for this connection's log lines: `"<process> as Server: "` or
/// `"<process> as Client: "`.
///
/// `<process>` is the file name of `argv[0]`. It falls back to `"Unknown"` when there is no
/// first argument or it has no file-name component.
fn get_log_prefix(is_server: bool) -> String {
    let role = if is_server { "Server" } else { "Client" };
    let argv0 = env::args().next().unwrap_or_else(|| String::from("Unknown"));
    let process_name = Path::new(&argv0)
        .file_name()
        .map(|name| name.to_string_lossy().into_owned())
        .unwrap_or_else(|| String::from("Unknown"));
    format!("{} as {}: ", process_name, role)
}
/// Lifecycle state of a connection.
#[derive(Clone, Debug)]
pub enum ConnectionStatus {
    /// Still waiting for a client to connect.
    WaitingForClient,
    /// The peers are connected.
    Connected,
    /// The session ended without error.
    DisconnectedCleanly,
    /// The session ended because of the contained error.
    DisconnectError(IpcRpcError),
}

impl ConnectionStatus {
    /// Outcome of the session, if it has ended.
    ///
    /// Returns:
    /// - `None` while waiting or connected;
    /// - `Some(Ok(()))` after a clean disconnect;
    /// - `Some(Err(_))` after a failed one, holding a clone of the stored error.
    pub fn session_end_result(&self) -> Option<Result<(), IpcRpcError>> {
        let outcome = match self {
            Self::WaitingForClient | Self::Connected => return None,
            Self::DisconnectedCleanly => Ok(()),
            Self::DisconnectError(error) => Err(error.clone()),
        };
        Some(outcome)
    }
}
/// Future that resolves to the reply for one outstanding request.
///
/// It resolves to whatever arrives on `receiver`: a reply kind or an error. If the sending
/// half is dropped before anything is delivered, it resolves to
/// `IpcRpcError::ConnectionDropped`.
// NOTE(review): presumably paired with the sender registered in a `PendingReplyEntry`.
struct IpcReplyFuture<U: UserMessage> {
    receiver: mpsc::UnboundedReceiver<Result<InternalMessageKind<U>, IpcRpcError>>,
}

impl<U: UserMessage> Future for IpcReplyFuture<U> {
    type Output = Result<InternalMessageKind<U>, IpcRpcError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.receiver.poll_recv(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(delivered) => Poll::Ready(
                delivered.unwrap_or_else(|| Err(IpcRpcError::ConnectionDropped)),
            ),
        }
    }
}
struct IpcReceiveStream<T> {
receiver: mpsc::UnboundedReceiver<Result<T, IpcError>>,
}
impl<T> IpcReceiveStream<T>
where
T: Send + for<'de> Deserialize<'de> + Serialize + 'static,
{
pub fn new(ipc_receiver: IpcReceiver<T>) -> Self {
let (sender, receiver) = mpsc::unbounded_channel();
tokio::task::spawn_blocking(move || loop {
match ipc_receiver.try_recv_timeout(Duration::from_millis(250)) {
Ok(msg) => {
if sender.send(Ok(msg)).is_err() {
break;
}
}
Err(TryRecvError::IpcError(e)) => {
if sender.send(Err(e)).is_err() {
break;
}
}
Err(TryRecvError::Empty) => {
if sender.is_closed() {
break;
}
}
}
});
Self { receiver }
}
}
impl<T> Stream for IpcReceiveStream<T> {
type Item = Result<T, IpcError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.receiver.poll_recv(cx)
}
}
/// Outcome of comparing both peers' user-message schemas.
#[derive(Clone, Debug)]
pub enum SchemaValidationStatus {
    /// Validation was not compiled in (presumably the `message-schema-validation` feature is off).
    ValidationDisabledAtCompileTime,
    /// The validation exchange did not complete as expected.
    ValidationNotPerformedProperly,
    /// Communicating with the peer failed during validation.
    ValidationCommunicationFailed(IpcRpcError),
    /// Both peers use the same schema.
    SchemasMatched,
    /// The schemas differ; both are included, serialized.
    SchemaMismatch {
        our_schema: String,
        their_schema: String,
    },
}

impl SchemaValidationStatus {
    /// Returns `true` only for [`SchemaValidationStatus::SchemasMatched`].
    pub fn is_success(&self) -> bool {
        matches!(*self, Self::SchemasMatched)
    }

    /// Asserts that the schemas matched.
    ///
    /// # Panics
    ///
    /// Panics, pretty-printing `self`, for every status other than `SchemasMatched`.
    pub fn assert_success(&self) {
        assert!(
            self.is_success(),
            "ipc-rpc user message schema failed to validate, error {self:#?}"
        );
    }
}
/// Bound met by every type that can travel as a user message.
///
/// With `message-schema-validation` the type must also implement `JsonSchema`, so peers can
/// compare schemas. The trait is blanket-implemented for every qualifying type.
#[cfg(feature = "message-schema-validation")]
pub trait UserMessage
where
    Self: 'static + Send + Debug + Clone + DeserializeOwned + Serialize + JsonSchema,
{
}

#[cfg(feature = "message-schema-validation")]
impl<T> UserMessage for T where
    T: 'static + Send + Debug + Clone + DeserializeOwned + Serialize + JsonSchema
{
}

/// Bound met by every type that can travel as a user message.
///
/// The trait is blanket-implemented for every qualifying type.
#[cfg(not(feature = "message-schema-validation"))]
pub trait UserMessage
where
    Self: 'static + Send + Debug + Clone + DeserializeOwned + Serialize,
{
}

#[cfg(not(feature = "message-schema-validation"))]
impl<T> UserMessage for T where T: 'static + Send + Debug + Clone + DeserializeOwned + Serialize {}
/// Sends a message and maps the reply through a single pattern.
///
/// Expands to `$sender.send($to_send).await.map(...)`, so it must be used in an `async`
/// context. It evaluates to the `Result` returned by `send`, with the `Ok` value replaced by
/// the value of the `=> { ... }` block. The trailing comma after the block is optional.
///
/// # Panics
///
/// Panics if the reply does not match `receiver`; the reply is included via `Debug`.
///
/// ```ignore
/// rpc_call!(
///     sender: client,
///     to_send: Request::Ping,
///     receiver: Reply::Pong(n) => { n },
/// )
/// ```
#[macro_export]
macro_rules! rpc_call {
    (sender: $sender:expr, to_send: $to_send:expr, receiver: $received:pat_param => $to_do:block $(,)?) => {
        $sender.send($to_send).await.map(|m| match m {
            $received => $to_do,
            other => panic!("rpc_call response didn't match given pattern: {:?}", other),
        })
    };
}
#[cfg(test)]
mod tests {
    // End-to-end tests: each one spins up a real server/client pair over IPC.
    use tokio::time::timeout;

    use super::*;

    #[cfg(not(feature = "message-schema-validation"))]
    compile_error!("Tests must be executed with all features on");

    #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
    pub struct IpcProtocolMessage {
        pub kind: IpcProtocolMessageKind,
    }

    #[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
    pub enum IpcProtocolMessageKind {
        TestMessage,
        ClientTestReply,
        ServerTestReply,
    }

    #[test_log::test(tokio::test(flavor = "multi_thread", worker_threads = 3))]
    async fn basic_dialogue() {
        let (server_key, mut server) =
            server::IpcRpcServer::initialize_server(|message: IpcProtocolMessage| async move {
                match message.kind {
                    IpcProtocolMessageKind::TestMessage => Some(IpcProtocolMessage {
                        kind: IpcProtocolMessageKind::ServerTestReply,
                    }),
                    _ => None,
                }
            })
            .await
            .unwrap();
        let mut client = client::IpcRpcClient::initialize_client(
            server_key,
            |message: IpcProtocolMessage| async move {
                match message.kind {
                    IpcProtocolMessageKind::TestMessage => Some(IpcProtocolMessage {
                        kind: IpcProtocolMessageKind::ClientTestReply,
                    }),
                    _ => None,
                }
            },
        )
        .await
        .unwrap();
        server.schema_validated().await.unwrap().assert_success();
        client.schema_validated().await.unwrap().assert_success();
        let client_reply = server
            .send(IpcProtocolMessage {
                kind: IpcProtocolMessageKind::TestMessage,
            })
            .await;
        if !matches!(
            client_reply.as_ref().map(|r| &r.kind),
            Ok(IpcProtocolMessageKind::ClientTestReply)
        ) {
            panic!("client reply was of unexpected type: {:?}", client_reply);
        }
        let server_reply = client
            .send(IpcProtocolMessage {
                kind: IpcProtocolMessageKind::TestMessage,
            })
            .await;
        if !matches!(
            server_reply.as_ref().map(|r| &r.kind),
            Ok(IpcProtocolMessageKind::ServerTestReply)
        ) {
            panic!("server reply was of unexpected type: {:?}", server_reply);
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread", worker_threads = 3))]
    async fn send_without_await() {
        let (server_success_sender, mut server_success_receiver) = mpsc::unbounded_channel();
        let (client_success_sender, mut client_success_receiver) = mpsc::unbounded_channel();
        let (server_key, mut server) =
            server::IpcRpcServer::initialize_server(move |message: IpcProtocolMessage| {
                let server_success_sender = server_success_sender.clone();
                async move {
                    match message.kind {
                        IpcProtocolMessageKind::TestMessage => {
                            server_success_sender.send(()).unwrap()
                        }
                        _ => {}
                    }
                    None
                }
            })
            .await
            .unwrap();
        let mut client = client::IpcRpcClient::initialize_client(
            server_key,
            move |message: IpcProtocolMessage| {
                let client_success_sender = client_success_sender.clone();
                async move {
                    match message.kind {
                        IpcProtocolMessageKind::TestMessage => {
                            client_success_sender.send(()).unwrap()
                        }
                        _ => {}
                    }
                    None
                }
            },
        )
        .await
        .unwrap();
        server.schema_validated().await.unwrap().assert_success();
        client.schema_validated().await.unwrap().assert_success();
        // The values returned by `send` are dropped immediately, never awaited.
        let _ = server.send(IpcProtocolMessage {
            kind: IpcProtocolMessageKind::TestMessage,
        });
        let _ = client.send(IpcProtocolMessage {
            kind: IpcProtocolMessageKind::TestMessage,
        });
        assert_eq!(
            timeout(Duration::from_secs(3), server_success_receiver.recv()).await,
            Ok(Some(()))
        );
        assert_eq!(
            timeout(Duration::from_secs(3), client_success_receiver.recv()).await,
            Ok(Some(()))
        );
    }

    #[test_log::test(tokio::test(flavor = "multi_thread", worker_threads = 3))]
    async fn timeout_test() {
        let (server_key, mut server) =
            server::IpcRpcServer::initialize_server(|message: IpcProtocolMessage| async move {
                match message.kind {
                    _ => None,
                }
            })
            .await
            .unwrap();
        let mut client = client::IpcRpcClient::initialize_client(
            server_key,
            |message: IpcProtocolMessage| async move {
                match message.kind {
                    _ => None,
                }
            },
        )
        .await
        .unwrap();
        server.schema_validated().await.unwrap().assert_success();
        client.schema_validated().await.unwrap().assert_success();
        let wait_start = Instant::now();
        let client_reply = server
            .send(IpcProtocolMessage {
                kind: IpcProtocolMessageKind::TestMessage,
            })
            .await;
        assert!(wait_start.elapsed() >= DEFAULT_REPLY_TIMEOUT);
        if !matches!(client_reply, Err(IpcRpcError::ReplyTimeout)) {
            panic!("client reply was of unexpected type: {:?}", client_reply);
        }
        let wait_start = Instant::now();
        let server_reply = client
            .send(IpcProtocolMessage {
                kind: IpcProtocolMessageKind::TestMessage,
            })
            .await;
        assert!(wait_start.elapsed() >= DEFAULT_REPLY_TIMEOUT);
        if !matches!(server_reply, Err(IpcRpcError::ReplyTimeout)) {
            panic!("server reply was of unexpected type: {:?}", server_reply);
        }
    }

    #[test_log::test(tokio::test(flavor = "multi_thread", worker_threads = 3))]
    async fn custom_timeout_test() {
        let (server_key, mut server) =
            server::IpcRpcServer::initialize_server(|message: IpcProtocolMessage| async move {
                match message.kind {
                    _ => None,
                }
            })
            .await
            .unwrap();
        let mut client = client::IpcRpcClient::initialize_client(
            server_key,
            |message: IpcProtocolMessage| async move {
                match message.kind {
                    _ => None,
                }
            },
        )
        .await
        .unwrap();
        server.schema_validated().await.unwrap().assert_success();
        client.schema_validated().await.unwrap().assert_success();
        let custom_timeout: Duration = DEFAULT_REPLY_TIMEOUT / 2;
        let wait_start = Instant::now();
        let client_reply = server
            .send_timeout(
                IpcProtocolMessage {
                    kind: IpcProtocolMessageKind::TestMessage,
                },
                custom_timeout,
            )
            .await;
        assert!(wait_start.elapsed() >= custom_timeout);
        assert!(wait_start.elapsed() < DEFAULT_REPLY_TIMEOUT);
        if !matches!(client_reply, Err(IpcRpcError::ReplyTimeout)) {
            panic!("client reply was of unexpected type: {:?}", client_reply);
        }
        let wait_start = Instant::now();
        let server_reply = client
            .send_timeout(
                IpcProtocolMessage {
                    kind: IpcProtocolMessageKind::TestMessage,
                },
                custom_timeout,
            )
            .await;
        assert!(wait_start.elapsed() >= custom_timeout);
        assert!(wait_start.elapsed() < DEFAULT_REPLY_TIMEOUT);
        if !matches!(server_reply, Err(IpcRpcError::ReplyTimeout)) {
            panic!("server reply was of unexpected type: {:?}", server_reply);
        }
    }

    #[test_log::test]
    fn server_drop_does_not_hang() {
        let thread = std::thread::spawn(|| {
            let runtime = tokio::runtime::Runtime::new().unwrap();
            runtime.block_on(async {
                let (_server_key, _server) = server::IpcRpcServer::initialize_server(
                    |message: IpcProtocolMessage| async move {
                        match message.kind {
                            _ => None,
                        }
                    },
                )
                .await
                .unwrap();
            })
        });
        let start = Instant::now();
        let timeout = Duration::from_secs(5);
        // Poll with a short sleep rather than spinning, so waiting does not burn a whole core.
        while !thread.is_finished() {
            if start.elapsed() >= timeout {
                // Fail only this test; the hung thread is detached and dies with the process.
                panic!("dropping the server did not finish within {timeout:?}");
            }
            std::thread::sleep(Duration::from_millis(10));
        }
    }

    #[test_log::test]
    fn server_disconnect_test() {
        let drop_detector = Arc::new(());
        let drop_detector_clone = drop_detector.clone();
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async move {
            let (server_key, server) = server::IpcRpcServer::initialize_server({
                let drop_detector_clone = drop_detector_clone.clone();
                move |message: IpcProtocolMessage| {
                    let drop_detector_clone = drop_detector_clone.clone();
                    async move {
                        match message.kind {
                            _ => {
                                let _ = drop_detector_clone.clone();
                                None
                            }
                        }
                    }
                }
            })
            .await
            .unwrap();
            let client = client::IpcRpcClient::initialize_client(
                server_key,
                |message: IpcProtocolMessage| async move {
                    match message.kind {
                        _ => None,
                    }
                },
            )
            .await
            .unwrap();
            // The original, the clone moved into this block, and the one owned by the handler.
            assert_eq!(Arc::strong_count(&drop_detector_clone), 3);
            drop(server);
            client.wait_for_server_to_disconnect().await.unwrap();
        });
        let start_shutdown = Instant::now();
        runtime.shutdown_timeout(Duration::from_secs(5));
        assert!(start_shutdown.elapsed() < Duration::from_secs(3));
        // Every clone, including the handler's, must be gone after shutdown.
        assert_eq!(Arc::strong_count(&drop_detector), 1);
    }

    #[test_log::test]
    fn client_disconnect_test() {
        use std::time::Instant;
        let drop_detector = Arc::new(());
        let drop_detector_clone = drop_detector.clone();
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async move {
            let (server_key, mut server) =
                server::IpcRpcServer::initialize_server(|message: IpcProtocolMessage| async move {
                    match message.kind {
                        _ => None,
                    }
                })
                .await
                .unwrap();
            let client = client::IpcRpcClient::initialize_client(server_key, {
                let drop_detector_clone = drop_detector_clone.clone();
                move |message: IpcProtocolMessage| {
                    let _drop_detector_clone = drop_detector_clone.clone();
                    async move {
                        match message.kind {
                            _ => None,
                        }
                    }
                }
            })
            .await
            .unwrap();
            // The original, the clone moved into this block, and the one owned by the handler.
            assert_eq!(Arc::strong_count(&drop_detector_clone), 3);
            drop(client);
            server.wait_for_client_to_disconnect().await.unwrap();
        });
        let start_shutdown = Instant::now();
        runtime.shutdown_timeout(Duration::from_secs(5));
        assert!(start_shutdown.elapsed() < Duration::from_secs(3));
        // Every clone, including the handler's, must be gone after shutdown.
        assert_eq!(Arc::strong_count(&drop_detector), 1);
    }

    #[test_log::test(tokio::test(flavor = "multi_thread", worker_threads = 3))]
    async fn rpc_call_macro_test() {
        let (server_key, mut server) =
            server::IpcRpcServer::initialize_server(|message: IpcProtocolMessage| async move {
                match message.kind {
                    _ => Some(IpcProtocolMessage {
                        kind: IpcProtocolMessageKind::ServerTestReply,
                    }),
                }
            })
            .await
            .unwrap();
        let mut client = client::IpcRpcClient::initialize_client(
            server_key,
            |message: IpcProtocolMessage| async move {
                match message.kind {
                    _ => Some(IpcProtocolMessage {
                        kind: IpcProtocolMessageKind::ClientTestReply,
                    }),
                }
            },
        )
        .await
        .unwrap();
        server.schema_validated().await.unwrap().assert_success();
        client.schema_validated().await.unwrap().assert_success();
        rpc_call!(
            sender: server,
            to_send: IpcProtocolMessage {
                kind: IpcProtocolMessageKind::TestMessage
            },
            receiver: IpcProtocolMessage {
                kind: IpcProtocolMessageKind::ClientTestReply
            } => {
            },
        )
        .unwrap();
        rpc_call!(
            sender: client,
            to_send: IpcProtocolMessage {
                kind: IpcProtocolMessageKind::TestMessage
            },
            receiver: IpcProtocolMessage {
                kind: IpcProtocolMessageKind::ServerTestReply
            } => {
            },
        )
        .unwrap();
    }
}